Implemented classifier cache
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	continuous-integration/drone/push Build is passing
				
			This commit is contained in:
		@@ -0,0 +1,76 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
import os
 | 
			
		||||
import logging
 | 
			
		||||
import requests
 | 
			
		||||
import tempfile
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from urllib.parse import urljoin
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClassifierCache:
    """Cache for the classifier built from the model service's $default model.

    Downloads the model and weights files only when the default model id
    changes between calls, reusing the already-loaded classifier otherwise.
    Downloaded temp files are tracked and deleted on reload.
    """

    def __init__(self, default_model_info_url: str = "http://model-service/model/cnn/$default"):
        # URL that returns a JSON document describing the current default model
        self._default_model_info_url = default_model_info_url

        self._current_model_details = None  # Should never be equal to the default model id
        self._current_classifier = None  # Latest classifier is a classifier that uses the $default model
        self._downloaded_files = []  # temp file paths to delete on cleanup

        # One session so repeated requests to the model service reuse connections
        self._session = requests.Session()

    def _cleanup(self):
        """Unload the current classifier and delete all downloaded temp files."""
        self._current_classifier = None
        self._current_model_details = None

        for file in self._downloaded_files:
            try:
                os.unlink(file)
            except FileNotFoundError:
                pass  # already gone — nothing to do

        self._downloaded_files = []

    def _fetch_to_file(self, file_url: str, file_handle: int, file_path: str):
        """Download `file_url` (relative or absolute) into the open temp file.

        The path must already be tracked in self._downloaded_files by the
        caller so a failed request cannot leak the temp file.
        """
        # Fun fact: urljoin is used to support both relative and absolute urls
        r = self._session.get(urljoin(self._default_model_info_url, file_url))
        r.raise_for_status()

        # open() takes ownership of the mkstemp fd and closes it on exit
        with open(file_handle, 'wb') as f:
            f.write(r.content)

    def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
        """Download the model/weights files and build the classifier from them.

        Raises:
            requests.HTTPError: if either download fails; partially
                downloaded files remain tracked and are removed by the
                next _cleanup().
        """
        model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json")
        weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5")

        # Track both paths *before* downloading: the original appended them
        # only after a successful fetch, so a failed raise_for_status()
        # leaked the temp files forever.
        self._downloaded_files.append(model_file_path)
        self._downloaded_files.append(weights_file_path)

        logging.debug("Fetching model file...")
        self._fetch_to_file(model_file_url, model_file_handle, model_file_path)

        logging.debug("Fetching weights file...")
        self._fetch_to_file(weights_file_url, weights_file_handle, weights_file_path)

        # magic happens here
        self._current_classifier = Classifier(model_file_path, weights_file_path)

    def get_default_classifier(self) -> Tuple[dict, Classifier]:
        """Return (model_details, classifier) for the service's current default model.

        Re-downloads and reloads only when the default model id changed
        since the previous call.

        Raises:
            requests.HTTPError: if any model-service request fails.
        """
        logging.debug("Fetching model info...")
        r = self._session.get(self._default_model_info_url)
        r.raise_for_status()

        model_details = r.json()

        if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']):
            # If the currently loaded model is not the default... then load it
            self._cleanup()  # delete/unload everything
            self._download_and_load_model(model_details['files']['model'], model_details['files']['weights'])
            self._current_model_details = model_details

        return self._current_model_details, self._current_classifier
 | 
			
		||||
 
 | 
			
		||||
@@ -4,80 +4,46 @@ import logging
 | 
			
		||||
import tempfile
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from urllib.parse import urljoin
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
from classifier_cache import ClassifierCache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MagicDoer:
    # Shared cache so consecutive jobs reuse the already-loaded default model
    classifier_cache = ClassifierCache()

    @classmethod
    def run_everything(cls, parameters: dict) -> dict:
        """Download the tagged sample, classify it with the default model,
        and return the classification result.

        Args:
            parameters: message payload; must contain a 'tag' key naming
                the object to fetch from the storage service.

        Returns:
            dict with 'tag', 'probability' (1.0 if the top prediction is
            the model's target class, else 0.0) and 'model' (model id).

        Raises:
            KeyError: if 'tag' is missing from parameters.
            requests.HTTPError: if the sample or model downloads fail.
        """
        tag = parameters['tag']

        sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")

        try:
            # Download Sample
            logging.info(f"Downloading sample: {tag}")
            r = requests.get(f"http://storage-service/object/{tag}")
            # Fail fast on HTTP errors instead of silently writing an
            # error page into the .wav and feeding it to the classifier
            # (the original checked nothing here).
            r.raise_for_status()

            # open() takes ownership of the mkstemp fd and closes it on exit
            with open(sample_file_handle, 'wb') as f:
                f.write(r.content)

            logging.debug(f"Downloaded sample to {sample_file_path}")

            # Get a classifier that uses the default model
            model_details, classifier = cls.classifier_cache.get_default_classifier()

            # do the majic
            results = classifier.predict(sample_file_path)

        finally:
            # Always remove the downloaded sample, even when classification failed
            try:
                os.remove(sample_file_path)
            except FileNotFoundError:
                pass

        response = {
            "tag": tag,
            "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
            "model": model_details['id']
        }

        logging.info("Classification done!")
        logging.debug(f"Results: {response}")

        return response
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,7 @@ import json
 | 
			
		||||
from sentry_sdk.integrations.logging import LoggingIntegration
 | 
			
		||||
import sentry_sdk
 | 
			
		||||
 | 
			
		||||
from magic_doer import run_everything
 | 
			
		||||
from magic_doer import MagicDoer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_callback(channel, method, properties, body):
 | 
			
		||||
@@ -18,7 +18,7 @@ def message_callback(channel, method, properties, body):
 | 
			
		||||
        logging.warning(f"Invalid message recieved: {e}")
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    results = run_everything(msg)  # <- This is where the magic happens
 | 
			
		||||
    results = MagicDoer.run_everything(msg)  # <- This is where the magic happens
 | 
			
		||||
 | 
			
		||||
    channel.basic_publish(
 | 
			
		||||
        exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
 | 
			
		||||
@@ -57,6 +57,7 @@ def main():
 | 
			
		||||
    queue_name = queue_declare_result.method.queue
 | 
			
		||||
 | 
			
		||||
    channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name)
 | 
			
		||||
 | 
			
		||||
    channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
 | 
			
		||||
 | 
			
		||||
    logging.info("Connection complete! Listening to messages...")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user