#!/usr/bin/env python3
"""uWSGI mule worker: classifies queued audio files with a pyAudioAnalysis
model and forwards positive matches to the input service."""
import sentry_sdk
import os
import requests
import json
import uwsgi
import pickle
from urllib.parse import urljoin
from config import Config
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy

if Config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=Config.SENTRY_DSN,
        send_default_pii=True,
        release=Config.RELEASE_ID,
        environment=Config.RELEASEMODE
    )


def json_datetime_dumper(o):
    """Serialize datetime-like objects for json.dumps; fall back to str()."""
    if hasattr(o, "isoformat"):
        return o.isoformat()
    return str(o)


class ModelMemer:
    """Downloads the classifier model on demand and keeps it loaded in memory."""

    def __init__(self):
        self._loaded_model = None

    def download_model_if_needed(self):
        models_dir = "/tmp/svm_model"
        os.makedirs(models_dir, exist_ok=True)
        model_file = os.path.join(models_dir, "model")
        means_file = os.path.join(models_dir, "modelMEANS")

        # Already downloaded and loaded in this process: nothing to do.
        if os.path.isfile(model_file) and self._loaded_model:
            return

        # Fetch the model descriptor, then the model and means files it points to.
        r = requests.get(Config.MODEL_INFO_URL)
        r.raise_for_status()
        self.model_details = r.json()

        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['model']))
        r.raise_for_status()
        with open(model_file, 'wb') as f:
            f.write(r.content)

        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['means']))
        r.raise_for_status()
        with open(means_file, 'wb') as f:
            f.write(r.content)

        if self.model_details['type'] == 'knn':
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model_knn(model_file)
        else:
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model(model_file)

        target_class_name = self.model_details['target_class_name']
        self.target_id = self.classes.index(target_class_name)
        self._loaded_model = True


def run_classification(audio_file_path: str, memer: ModelMemer) -> bool:
    memer.download_model_if_needed()

    # Read the audio file and collapse it to a mono signal.
    sampling_rate, signal = audioBasicIO.read_audio_file(audio_file_path)
    signal = audioBasicIO.stereo_to_mono(signal)

    if sampling_rate == 0:
        raise AssertionError("Could not read the file properly: sampling rate is zero")

    if signal.shape[0] / float(sampling_rate) <= memer.mid_window:
        raise AssertionError("Could not read the file properly: signal is shorter than one mid-term window")

    # Mid-term feature extraction.
    mid_features, s, _ = MidTermFeatures.mid_feature_extraction(
        signal, sampling_rate,
        memer.mid_window * sampling_rate,
        memer.mid_step * sampling_rate,
        round(sampling_rate * memer.short_window),
        round(sampling_rate * memer.short_step))

    # Long-term averaging of the mid-term statistics.
    mid_features = mid_features.mean(axis=1)
    if memer.compute_beat:
        beat, beat_conf = MidTermFeatures.beat_extraction(s, memer.short_step)
        mid_features = numpy.append(mid_features, beat)
        mid_features = numpy.append(mid_features, beat_conf)

    # Normalize with the training means/stds and classify.
    feature_vector = (mid_features - memer.mean) / memer.std
    class_id, probability = classifier_wrapper(
        memer.classifier,
        memer.model_details['type'].lower(),
        feature_vector
    )
    class_id = int(class_id)  # classifier_wrapper returns a numpy scalar; cast for indexing
    return bool((class_id == memer.target_id) and (probability[class_id] > 0.5))


def main():
    memer = ModelMemer()
    while True:
        # Block until a web worker sends a pickled task to this mule.
        message = uwsgi.mule_get_msg()
        task = pickle.loads(message)
        audio_file_path = task['audio_file_path']
        description = task['description']
        try:
            result = run_classification(audio_file_path, memer)
            if result:
                # Upload the matching recording to the real input service.
                with open(audio_file_path, 'rb') as audio_file:
                    audio_bytes = audio_file.read()
                files = {
                    "file": (
                        os.path.basename(audio_file_path),
                        audio_bytes,
                        'audio/wave',
                        {'Content-length': os.path.getsize(audio_file_path)}
                    ),
                    "description": (None,
                                    json.dumps(description, default=json_datetime_dumper),
                                    "application/json")
                }
                r = requests.post(Config.INPUT_SERVICE_URL, files=files)
                r.raise_for_status()
        finally:
            os.remove(audio_file_path)


if __name__ == '__main__':
    main()
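
# ---------------------------------------------------------------------------
# Producer-side sketch (an assumption, not part of this worker): a uWSGI web
# worker can enqueue a job for this mule by pickling a dict with the same keys
# read in main() ('audio_file_path', 'description') and sending it with
# uwsgi.mule_msg(). The file path and mule number below are hypothetical and
# depend on the deployment's uWSGI configuration.
#
#   import pickle
#   from datetime import datetime, timezone
#   import uwsgi
#
#   task = {
#       'audio_file_path': '/tmp/upload_example.wav',  # hypothetical path
#       'description': {'received_at': datetime.now(timezone.utc)},
#   }
#   uwsgi.mule_msg(pickle.dumps(task), 1)  # deliver to mule 1
# ---------------------------------------------------------------------------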