#!/usr/bin/env python3
"""uWSGI mule that classifies feature vectors with models from model-service.

Each message taken from the mule queue is a JSON task. The model it names is
downloaded from model-service, applied to the task's feature vector, and a
result dict is published to a RabbitMQ fanout exchange.

Environment variables:
    SENTRY_DSN             optional; enables Sentry error reporting
    RELEASE_ID             Sentry release tag (default 'test')
    RELEASEMODE            Sentry environment (default 'dev')
    PIKA_URL               AMQP connection URL
    PIKA_EXCHANGE          fanout exchange that results are published to
    TARGET_CLASS_NAME      class label whose probability is reported
    MODEL_SERVICE_URL      base URL of model-service (default 'http://model-service')
    MODEL_SERVICE_TIMEOUT  connect/read timeout in seconds for model-service (default 120)
"""
import contextlib
import json
import os
import tempfile

import numpy
import pika
import requests
import sentry_sdk
import uwsgi
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper

SENTRY_DSN = os.environ.get("SENTRY_DSN")
if SENTRY_DSN:
    sentry_sdk.init(
        dsn=SENTRY_DSN,
        send_default_pii=True,
        release=os.environ.get('RELEASE_ID', 'test'),
        environment=os.environ.get('RELEASEMODE', 'dev')
    )

MODEL_SERVICE_URL = os.environ.get("MODEL_SERVICE_URL", "http://model-service")

# Without a timeout, requests waits forever on a stalled server, which would
# wedge this mule permanently. The value bounds the connect and each socket
# read, not the total transfer time, so large model downloads are unaffected.
REQUEST_TIMEOUT = float(os.environ.get("MODEL_SERVICE_TIMEOUT", "120"))

RESULT_ROUTING_KEY = 'classification-result'


def _download(url: str, path: str) -> None:
    """Fetch *url* and write the response body to *path*.

    Raises:
        requests.RequestException: if the request fails, times out, or
            model-service answers with an error status.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    with open(path, 'wb') as f:
        f.write(response.content)


def run_classification(task: dict, target_class_name: str) -> dict:
    """Classify ``task['features']`` with the model named by ``task['model']``.

    Args:
        task: decoded task message; reads the keys ``'model'`` (model id used
            in the model-service URLs), ``'features'`` (feature vector) and
            ``'tag'`` (echoed back in the result).
        target_class_name: class label whose probability is reported.

    Returns:
        dict with ``tag`` and ``model`` copied from the task, ``is_target``
        (whether the predicted class is the target class) and ``probability``
        (the probability the classifier assigns to the target class).

    Raises:
        requests.RequestException: if a model-service request fails or times out.
        ValueError: if the model has no class named *target_class_name*.
        KeyError: if *task* or the model details lack an expected key.
    """
    model_url = f"{MODEL_SERVICE_URL}/model/{task['model']}"

    # Fetch the details before creating any temp file, so a failure here
    # leaves nothing behind on disk.
    r = requests.get(f"{model_url}/details", timeout=REQUEST_TIMEOUT)
    r.raise_for_status()
    model_details = r.json()

    # mkstemp() returns an open OS-level descriptor as well as the path. Only
    # the path is needed, so close the descriptor now instead of leaking one
    # per task in this long-running process.
    fd, temp_model_name = tempfile.mkstemp()
    os.close(fd)
    # load_model() is given only the model path, so it presumably locates the
    # means through this "<model>MEANS" naming convention.
    temp_means_name = temp_model_name + "MEANS"
    try:
        _download(model_url, temp_model_name)
        _download(f"{model_url}?means", temp_means_name)

        # NOTE(review): pyAudioAnalysis model files are, as far as I know,
        # pickle-based, so loading them can execute arbitrary code. Only
        # point MODEL_SERVICE_URL at a trusted service.
        loader = load_model_knn if model_details['type'] == 'knn' else load_model
        # Only the classifier, normalisation stats and class labels are needed;
        # the window/step/beat settings the loaders also return are ignored.
        classifier, mean, std, classes, *_ = loader(temp_model_name)

        try:
            target_id = classes.index(target_class_name)
        except ValueError:
            raise ValueError(
                f"target class {target_class_name!r} not among model classes {classes!r}"
            ) from None

        feature_vector = (numpy.array(task['features']) - mean) / std
        class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector)
    finally:
        # Always remove both temp files. The means file may not exist yet if
        # an earlier step raised.
        for path in (temp_model_name, temp_means_name):
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

    return {
        "tag": task['tag'],
        "model": task['model'],
        "is_target": bool(class_id == target_id),
        # Cast to a builtin float: a numpy scalar such as float32 is not
        # accepted by json.dumps(), and the publisher serialises this dict.
        "probability": float(probability[target_id]),
    }


def _connect(url: str, exchange: str):
    """Open a blocking AMQP connection and declare the fanout result exchange.

    Returns:
        ``(connection, channel)`` tuple.
    """
    connection = pika.BlockingConnection(pika.connection.URLParameters(url))
    channel = connection.channel()
    channel.exchange_declare(exchange=exchange, exchange_type='fanout')
    return connection, channel


def main():
    """Classify tasks from the uWSGI mule queue forever, publishing each result."""
    # Read configuration once, so a missing variable fails at startup rather
    # than on the first task.
    pika_url = os.environ['PIKA_URL']
    exchange = os.environ['PIKA_EXCHANGE']
    target_class_name = os.environ['TARGET_CLASS_NAME']

    connection, channel = _connect(pika_url, exchange)
    while True:
        message = uwsgi.mule_get_msg()
        task = json.loads(message)
        results = run_classification(task, target_class_name)
        body = json.dumps(results).encode("utf-8")
        try:
            channel.basic_publish(exchange=exchange, routing_key=RESULT_ROUTING_KEY, body=body)
        except pika.exceptions.AMQPError:
            # A BlockingConnection only services heartbeats while pika itself
            # is being called. While this loop waits in mule_get_msg(), the
            # broker may drop the connection. Reconnect once and retry rather
            # than losing this result.
            with contextlib.suppress(pika.exceptions.AMQPError):
                connection.close()
            connection, channel = _connect(pika_url, exchange)
            channel.basic_publish(exchange=exchange, routing_key=RESULT_ROUTING_KEY, body=body)


if __name__ == '__main__':
    main()