Implemented classifier cache
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
01fa54e6b6
commit
751fbdc4a5
@ -0,0 +1,76 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
from typing import Tuple
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
from cnn_classifier import Classifier
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierCache:
    """Caches a Classifier built from the model-service's $default CNN model.

    Downloads the model definition and weights into temp files on demand and
    reuses the loaded Classifier until the service reports a different
    default model id.
    """

    def __init__(self, default_model_info_url: str = "http://model-service/model/cnn/$default"):
        # URL returning JSON metadata for the current default model.
        self._default_model_info_url = default_model_info_url

        self._current_model_details = None  # Should never be equal to the default model id
        self._current_classifier = None  # Latest classifier is a classifier that uses the $default model
        self._downloaded_files = []  # temp file paths _cleanup() must delete

        # One Session so HTTP connections are reused across the several downloads.
        self._session = requests.Session()

    def _cleanup(self):
        """Drop the cached classifier and delete every downloaded temp file."""
        self._current_classifier = None
        self._current_model_details = None
        for file in self._downloaded_files:
            try:
                os.unlink(file)
            except FileNotFoundError:
                pass  # already gone — nothing to do

        self._downloaded_files = []

    def _download_file(self, url: str, suffix: str) -> str:
        """Download *url* (relative or absolute, resolved against the
        model-info URL) into a fresh temp file and return its path.

        The path is registered in ``_downloaded_files`` *before* the request
        is made, so a failed download cannot leak the temp file — the next
        ``_cleanup()`` removes it. The raw mkstemp fd is closed immediately
        to avoid leaking it across the network call.
        """
        handle, path = tempfile.mkstemp(suffix=suffix)
        os.close(handle)  # re-opened by path below; don't hold the raw fd
        self._downloaded_files.append(path)

        # Fun fact: urljoin is used to support both relative and absolute urls
        r = self._session.get(urljoin(self._default_model_info_url, url))
        r.raise_for_status()

        with open(path, 'wb') as f:
            f.write(r.content)
        return path

    def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
        """Fetch model json + weights and build the cached Classifier."""
        logging.debug("Fetching model file...")
        model_file_path = self._download_file(model_file_url, ".json")

        logging.debug("Fetching weights file...")
        weights_file_path = self._download_file(weights_file_url, ".h5")

        # magic happens here
        self._current_classifier = Classifier(model_file_path, weights_file_path)

    def get_default_classifier(self) -> Tuple[dict, Classifier]:
        """Return ``(model_details, classifier)`` for the service's current
        default model, re-downloading only when the default model id changed.

        Raises ``requests.HTTPError`` on any failed request; any partially
        downloaded files are reclaimed by the next call's ``_cleanup()``.
        """
        logging.debug("Fetching model info...")
        r = self._session.get(self._default_model_info_url)
        r.raise_for_status()

        model_details = r.json()

        if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']):
            # If the currently loaded model is not the default... then load it
            self._cleanup()  # delete/unload everything
            self._download_and_load_model(model_details['files']['model'], model_details['files']['weights'])
            self._current_model_details = model_details

        return self._current_model_details, self._current_classifier
|
@ -4,18 +4,16 @@ import logging
|
|||||||
import tempfile
|
import tempfile
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from urllib.parse import urljoin
|
from classifier_cache import ClassifierCache
|
||||||
|
|
||||||
from cnn_classifier import Classifier
|
|
||||||
|
|
||||||
|
|
||||||
def run_everything(parameters: dict):
|
class MagicDoer:
|
||||||
|
classifier_cache = ClassifierCache()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_everything(cls, parameters: dict) -> dict:
|
||||||
tag = parameters['tag']
|
tag = parameters['tag']
|
||||||
|
|
||||||
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||||
model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json")
|
|
||||||
weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Download Sample
|
# Download Sample
|
||||||
@ -27,45 +25,13 @@ def run_everything(parameters: dict):
|
|||||||
|
|
||||||
logging.debug(f"Downloaded sample to {sample_file_path}")
|
logging.debug(f"Downloaded sample to {sample_file_path}")
|
||||||
|
|
||||||
# Download model
|
# Get a classifier that uses the default model
|
||||||
|
model_details, classifier = cls.classifier_cache.get_default_classifier()
|
||||||
|
|
||||||
model_root_url = "http://model-service/model/cnn/$default"
|
# do the magic
|
||||||
|
|
||||||
logging.debug("Fetching model info...")
|
|
||||||
r = requests.get(model_root_url)
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
model_details = r.json()
|
|
||||||
|
|
||||||
logging.debug("Fetching model file...")
|
|
||||||
r = requests.get(urljoin(model_root_url, model_details['files']['model'])) # Fun fact: this would support external urls
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
with open(model_file_handle, 'wb') as f:
|
|
||||||
f.write(r.content)
|
|
||||||
|
|
||||||
logging.debug("Fetching weights file...")
|
|
||||||
r = requests.get(urljoin(model_root_url, model_details['files']['weights']))
|
|
||||||
r.raise_for_status()
|
|
||||||
|
|
||||||
with open(weights_file_handle, 'wb') as f:
|
|
||||||
f.write(r.content)
|
|
||||||
|
|
||||||
# magic happens here
|
|
||||||
classifier = Classifier(model_file_path, weights_file_path)
|
|
||||||
results = classifier.predict(sample_file_path)
|
results = classifier.predict(sample_file_path)
|
||||||
|
|
||||||
finally: # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh
|
finally:
|
||||||
try:
|
|
||||||
os.remove(model_file_path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.remove(weights_file_path)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.remove(sample_file_path)
|
os.remove(sample_file_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
@ -8,7 +8,7 @@ import json
|
|||||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
|
||||||
from magic_doer import run_everything
|
from magic_doer import MagicDoer
|
||||||
|
|
||||||
|
|
||||||
def message_callback(channel, method, properties, body):
|
def message_callback(channel, method, properties, body):
|
||||||
@ -18,7 +18,7 @@ def message_callback(channel, method, properties, body):
|
|||||||
logging.warning(f"Invalid message recieved: {e}")
|
logging.warning(f"Invalid message recieved: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
results = run_everything(msg) # <- This is where the magic happens
|
results = MagicDoer.run_everything(msg) # <- This is where the magic happens
|
||||||
|
|
||||||
channel.basic_publish(
|
channel.basic_publish(
|
||||||
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
||||||
@ -57,6 +57,7 @@ def main():
|
|||||||
queue_name = queue_declare_result.method.queue
|
queue_name = queue_declare_result.method.queue
|
||||||
|
|
||||||
channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name)
|
channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name)
|
||||||
|
|
||||||
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
||||||
|
|
||||||
logging.info("Connection complete! Listening to messages...")
|
logging.info("Connection complete! Listening to messages...")
|
||||||
|
Loading…
Reference in New Issue
Block a user