From 751fbdc4a537205812c709a10e6d2fbcf37f19b1 Mon Sep 17 00:00:00 2001 From: marcsello Date: Tue, 6 Oct 2020 00:41:54 +0200 Subject: [PATCH] Implemented classifier cache --- .../classifier_cache.py | 76 +++++++++++++ cnn_classification_service/magic_doer.py | 102 ++++++------------ cnn_classification_service/main.py | 5 +- 3 files changed, 113 insertions(+), 70 deletions(-) diff --git a/cnn_classification_service/classifier_cache.py b/cnn_classification_service/classifier_cache.py index e69de29..b4f14e5 100644 --- a/cnn_classification_service/classifier_cache.py +++ b/cnn_classification_service/classifier_cache.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +import os +import logging +import requests +import tempfile +from typing import Tuple +from urllib.parse import urljoin + +from cnn_classifier import Classifier + + +class ClassifierCache: + + def __init__(self, default_model_info_url: str = "http://model-service/model/cnn/$default"): + self._default_model_info_url = default_model_info_url + + self._current_model_details = None # Should never be equal to the default model id + self._current_classifier = None # Latest classifier is a classifier that uses the $default model + self._downloaded_files = [] + + self._session = requests.Session() + + def _cleanup(self): + self._current_classifier = None + self._current_model_details = None + for file in self._downloaded_files: + try: + os.unlink(file) + except FileNotFoundError: + pass + + self._downloaded_files = [] + + def _download_and_load_model(self, model_file_url: str, weights_file_url: str): + model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json") + weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5") + + logging.debug("Fetching model file...") + r = self._session.get( # Fun fact: urljoin is used to support both relative and absolute urls + urljoin(self._default_model_info_url, model_file_url) + ) + r.raise_for_status() + + with open(model_file_handle, 'wb') as f: + f.write(r.content) + + self._downloaded_files.append(model_file_path) + + logging.debug("Fetching weights file...") + r = self._session.get( + urljoin(self._default_model_info_url, weights_file_url) + ) + r.raise_for_status() + + with open(weights_file_handle, 'wb') as f: + f.write(r.content) + + self._downloaded_files.append(weights_file_path) + + # magic happens here + self._current_classifier = Classifier(model_file_path, weights_file_path) + + def get_default_classifier(self) -> Tuple[dict, Classifier]: + logging.debug("Fetching model info...") + r = self._session.get(self._default_model_info_url) + r.raise_for_status() + + model_details = r.json() + + if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']): + # If the currently loaded model is not the default... then load it + self._cleanup() # delete/unload everything + self._download_and_load_model(model_details['files']['model'], model_details['files']['weights']) + self._current_model_details = model_details + + return self._current_model_details, self._current_classifier diff --git a/cnn_classification_service/magic_doer.py b/cnn_classification_service/magic_doer.py index a31b01b..6a2337c 100644 --- a/cnn_classification_service/magic_doer.py +++ b/cnn_classification_service/magic_doer.py @@ -4,80 +4,46 @@ import logging import tempfile import requests -from urllib.parse import urljoin - -from cnn_classifier import Classifier +from classifier_cache import ClassifierCache -def run_everything(parameters: dict): - tag = parameters['tag'] +class MagicDoer: + classifier_cache = ClassifierCache() - sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") - model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json") - weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5") - - try: - - # Download Sample - - logging.info(f"Downloading sample: {tag}") - r = requests.get(f"http://storage-service/object/{tag}") - with open(sample_file_handle, 'wb') as f: - f.write(r.content) - - logging.debug(f"Downloaded sample to {sample_file_path}") - - # Download model - - model_root_url = "http://model-service/model/cnn/$default" - - logging.debug("Fetching model info...") - r = requests.get(model_root_url) - r.raise_for_status() - - model_details = r.json() - - logging.debug("Fetching model file...") - r = requests.get(urljoin(model_root_url, model_details['files']['model'])) # Fun fact: this would support external urls - r.raise_for_status() - - with open(model_file_handle, 'wb') as f: - f.write(r.content) - - logging.debug("Fetching weights file...") - r = requests.get(urljoin(model_root_url, model_details['files']['weights'])) - r.raise_for_status() - - with open(weights_file_handle, 'wb') as f: - f.write(r.content) - - # magic happens here - classifier = Classifier(model_file_path, weights_file_path) - results = classifier.predict(sample_file_path) - - finally: # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh + @classmethod + def run_everything(cls, parameters: dict) -> dict: + tag = parameters['tag'] + sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") try: - os.remove(model_file_path) - except FileNotFoundError: - pass - try: - os.remove(weights_file_path) - except FileNotFoundError: - pass + # Download Sample - try: - os.remove(sample_file_path) - except FileNotFoundError: - pass + logging.info(f"Downloading sample: {tag}") + r = requests.get(f"http://storage-service/object/{tag}") + with open(sample_file_handle, 'wb') as f: + f.write(r.content) - response = { - "tag": tag, - "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0, - "model": model_details['id'] - } + logging.debug(f"Downloaded sample to {sample_file_path}") - logging.info(f"Classification done!") - logging.debug(f"Results: {response}") + # Get a classifier that uses the default model + model_details, classifier = cls.classifier_cache.get_default_classifier() - return response + # do the majic + results = classifier.predict(sample_file_path) + + finally: + try: + os.remove(sample_file_path) + except FileNotFoundError: + pass + + response = { + "tag": tag, + "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0, + "model": model_details['id'] + } + + logging.info(f"Classification done!") + logging.debug(f"Results: {response}") + + return response diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py index 5d0fdec..3b53f6e 100644 --- a/cnn_classification_service/main.py +++ b/cnn_classification_service/main.py @@ -8,7 +8,7 @@ import json from sentry_sdk.integrations.logging import LoggingIntegration import sentry_sdk -from magic_doer import run_everything +from magic_doer import MagicDoer def message_callback(channel, method, properties, body): @@ -18,7 +18,7 @@ def message_callback(channel, method, properties, body): logging.warning(f"Invalid message recieved: {e}") return - results = run_everything(msg) # <- This is where the magic happens + results = MagicDoer.run_everything(msg) # <- This is where the magic happens channel.basic_publish( exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], @@ -57,6 +57,7 @@ def main(): queue_name = queue_declare_result.method.queue channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name) + channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True) logging.info("Connection complete! Listening to messages...")