Implemented classifier cache
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -4,80 +4,46 @@ import logging
|
||||
import tempfile
|
||||
import requests
|
||||
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cnn_classifier import Classifier
|
||||
from classifier_cache import ClassifierCache
|
||||
|
||||
|
||||
def run_everything(parameters: dict):
|
||||
tag = parameters['tag']
|
||||
class MagicDoer:
|
||||
classifier_cache = ClassifierCache()
|
||||
|
||||
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||
model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json")
|
||||
weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5")
|
||||
|
||||
try:
|
||||
|
||||
# Download Sample
|
||||
|
||||
logging.info(f"Downloading sample: {tag}")
|
||||
r = requests.get(f"http://storage-service/object/{tag}")
|
||||
with open(sample_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
logging.debug(f"Downloaded sample to {sample_file_path}")
|
||||
|
||||
# Download model
|
||||
|
||||
model_root_url = "http://model-service/model/cnn/$default"
|
||||
|
||||
logging.debug("Fetching model info...")
|
||||
r = requests.get(model_root_url)
|
||||
r.raise_for_status()
|
||||
|
||||
model_details = r.json()
|
||||
|
||||
logging.debug("Fetching model file...")
|
||||
r = requests.get(urljoin(model_root_url, model_details['files']['model'])) # Fun fact: this would support external urls
|
||||
r.raise_for_status()
|
||||
|
||||
with open(model_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
logging.debug("Fetching weights file...")
|
||||
r = requests.get(urljoin(model_root_url, model_details['files']['weights']))
|
||||
r.raise_for_status()
|
||||
|
||||
with open(weights_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
# magic happens here
|
||||
classifier = Classifier(model_file_path, weights_file_path)
|
||||
results = classifier.predict(sample_file_path)
|
||||
|
||||
finally: # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh
|
||||
@classmethod
|
||||
def run_everything(cls, parameters: dict) -> dict:
|
||||
tag = parameters['tag']
|
||||
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||
try:
|
||||
os.remove(model_file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
os.remove(weights_file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
# Download Sample
|
||||
|
||||
try:
|
||||
os.remove(sample_file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
logging.info(f"Downloading sample: {tag}")
|
||||
r = requests.get(f"http://storage-service/object/{tag}")
|
||||
with open(sample_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
response = {
|
||||
"tag": tag,
|
||||
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
|
||||
"model": model_details['id']
|
||||
}
|
||||
logging.debug(f"Downloaded sample to {sample_file_path}")
|
||||
|
||||
logging.info(f"Classification done!")
|
||||
logging.debug(f"Results: {response}")
|
||||
# Get a classifier that uses the default model
|
||||
model_details, classifier = cls.classifier_cache.get_default_classifier()
|
||||
|
||||
return response
|
||||
# do the majic
|
||||
results = classifier.predict(sample_file_path)
|
||||
|
||||
finally:
|
||||
try:
|
||||
os.remove(sample_file_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
response = {
|
||||
"tag": tag,
|
||||
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
|
||||
"model": model_details['id']
|
||||
}
|
||||
|
||||
logging.info(f"Classification done!")
|
||||
logging.debug(f"Results: {response}")
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user