svm-prefilter-service/svm_prefilter_service/mule.py

144 lines
4.7 KiB
Python

#!/usr/bin/env python3
import contextlib
import json
import logging
import os
import pickle
from urllib.parse import urljoin

import numpy
import requests
import sentry_sdk
import uwsgi
from pyAudioAnalysis import MidTermFeatures
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper

from config import Config
# Error reporting is opt-in: Sentry is only initialised when a DSN is configured.
if Config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=Config.SENTRY_DSN,
        release=Config.RELEASE_ID,
        environment=Config.RELEASEMODE,
        # Opt in to Sentry's default PII (e.g. user/request data) on events.
        send_default_pii=True,
    )
def json_datetime_dumper(o):
    """Fallback serializer intended for ``json.dumps(..., default=...)``.

    Anything exposing an ``isoformat`` method (dates, datetimes, times) is
    rendered through it; every other object is converted with ``str``.
    """
    return o.isoformat() if hasattr(o, "isoformat") else str(o)
class ModelMemer:
    """Downloads the classifier described at ``Config.MODEL_INFO_URL`` and keeps it in memory.

    After a successful ``download_model_if_needed`` call the instance exposes
    ``model_details`` (the JSON model description), the values returned by the
    pyAudioAnalysis loader (``classifier``, ``mean``, ``std``, ``classes``,
    ``mid_window``, ``mid_step``, ``short_window``, ``short_step``,
    ``compute_beat``) and ``target_id`` (index of the target class in ``classes``).
    """

    def __init__(self, download_timeout: float = 60.0):
        # Loaded classifier, or None until a download+load has fully succeeded.
        self._loaded_model = None
        # Timeout (seconds) for each HTTP request, so a stalled model server
        # cannot block the mule indefinitely.
        self.download_timeout = download_timeout

    def _fetch(self, url):
        """GET *url* and return the response; raise ``requests.HTTPError`` on an error status."""
        response = requests.get(url, timeout=self.download_timeout)
        response.raise_for_status()
        return response

    def _download_to(self, relative_url, path):
        """Download *relative_url* (resolved against ``MODEL_INFO_URL``) and write it to *path*."""
        response = self._fetch(urljoin(Config.MODEL_INFO_URL, relative_url))
        with open(path, 'wb') as f:
            f.write(response.content)

    def download_model_if_needed(self):
        """Fetch and load the model unless it is already loaded and its file is still on disk.

        Raises:
            requests.HTTPError / requests.RequestException: On download failure.
            ValueError: If ``target_class_name`` is not one of the model's classes.
        """
        models_dir = "/tmp/svm_model"
        os.makedirs(models_dir, exist_ok=True)
        model_file = os.path.join(models_dir, "model")
        # NOTE(review): the "MEANS" suffix presumably matches what the
        # pyAudioAnalysis loaders expect next to model_file — confirm.
        means_file = os.path.join(models_dir, "modelMEANS")

        if self._loaded_model is not None and os.path.isfile(model_file):
            return

        self.model_details = self._fetch(Config.MODEL_INFO_URL).json()
        self._download_to(self.model_details['files']['model'], model_file)
        self._download_to(self.model_details['files']['means'], means_file)

        # Case-insensitive, consistent with run_classification's `.lower()`.
        is_knn = self.model_details['type'].lower() == 'knn'
        loader = load_model_knn if is_knn else load_model
        # NOTE(review): pyAudioAnalysis loaders are believed to unpickle these
        # files; MODEL_INFO_URL must point at a trusted server — confirm.
        (self.classifier, self.mean, self.std, self.classes, self.mid_window,
         self.mid_step, self.short_window, self.short_step,
         self.compute_beat) = loader(model_file)

        target_class_name = self.model_details['target_class_name']
        self.target_id = self.classes.index(target_class_name)

        # Mark as loaded only after every step above succeeded, so a failed
        # download/load is retried on the next call instead of being cached.
        self._loaded_model = self.classifier
def run_classification(audio_file_path: str, memer: ModelMemer, threshold: float = 0.5) -> bool:
    """Decide whether an audio file belongs to the model's target class.

    Args:
        audio_file_path: Path of the audio file to classify.
        memer: Model holder; its model is downloaded/loaded on first use.
        threshold: The target class's probability must be strictly greater
            than this value. Defaults to 0.5.

    Returns:
        True if the predicted class is ``memer.target_id`` and its probability
        exceeds *threshold*, otherwise False.

    Raises:
        AssertionError: If the file could not be read, or it is not longer than
            one mid-term window.
    """
    memer.download_model_if_needed()

    sampling_rate, signal = audioBasicIO.read_audio_file(audio_file_path)
    # Validate the rate before using the signal. NOTE(review): some
    # pyAudioAnalysis readers appear to report failure with -1 rather than 0;
    # `<= 0` covers both.
    if sampling_rate <= 0:
        raise AssertionError(f"Could not read the file properly: invalid sampling rate {sampling_rate}")
    signal = audioBasicIO.stereo_to_mono(signal)
    # Duration in seconds must exceed one mid-term window to yield features.
    if signal.shape[0] / float(sampling_rate) <= memer.mid_window:
        raise AssertionError("Could not read the file properly: Signal shape is not good")

    # Window/step sizes are stored in seconds; convert them to samples.
    mid_features, short_features, _ = MidTermFeatures.mid_feature_extraction(
        signal,
        sampling_rate,
        memer.mid_window * sampling_rate,
        memer.mid_step * sampling_rate,
        round(sampling_rate * memer.short_window),
        round(sampling_rate * memer.short_step),
    )
    # Long-term averaging of the mid-term statistics.
    mid_features = mid_features.mean(axis=1)
    if memer.compute_beat:
        beat, beat_conf = MidTermFeatures.beat_extraction(short_features, memer.short_step)
        mid_features = numpy.append(mid_features, beat)
        mid_features = numpy.append(mid_features, beat_conf)
    # Normalise with the mean/std returned by the model loader.
    feature_vector = (mid_features - memer.mean) / memer.std

    class_id, probability = classifier_wrapper(
        memer.classifier, memer.model_details['type'].lower(), feature_vector
    )
    # The predicted label is not guaranteed to be an int (presumably a float
    # label); cast it so it can index `probability` and compare to target_id.
    class_id = int(class_id)
    return bool(class_id == memer.target_id and probability[class_id] > threshold)
# Timeout (seconds) for requests to the input service.
_UPLOAD_TIMEOUT_SECONDS = 60


def _upload_to_input_service(audio_file_path, description, timeout=_UPLOAD_TIMEOUT_SECONDS):
    """POST the audio file and its JSON-encoded description to ``Config.INPUT_SERVICE_URL``.

    Raises:
        requests.HTTPError: If the input service responds with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    # Context manager closes the handle right away instead of leaking it.
    with open(audio_file_path, 'rb') as audio_file:
        audio_bytes = audio_file.read()
    files = {
        "file": (
            os.path.basename(audio_file_path),
            audio_bytes,
            'audio/wave',
            {'Content-length': len(audio_bytes)},
        ),
        "description": (None, json.dumps(description, default=json_datetime_dumper), "application/json"),
    }
    r = requests.post(Config.INPUT_SERVICE_URL, files=files, timeout=timeout)
    r.raise_for_status()


def _process_task(task, memer):
    """Classify one queued audio file, forward it if it matches, and always delete it."""
    audio_file_path = task['audio_file_path']
    try:
        if run_classification(audio_file_path, memer):
            _upload_to_input_service(audio_file_path, task['description'])
    finally:
        # A missing file must not mask the real exception from the try body.
        with contextlib.suppress(FileNotFoundError):
            os.remove(audio_file_path)


def main():
    """Run the uWSGI mule loop: receive tasks, classify, forward positives.

    Each message is a pickled dict with ``audio_file_path`` and ``description``
    keys. A failing task is logged and reported to Sentry, and the loop
    continues, so the mule and its in-memory model survive bad input.
    """
    logger = logging.getLogger(__name__)
    memer = ModelMemer()
    while True:
        message = uwsgi.mule_get_msg()
        try:
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # assumes messages only come from this app's own workers — confirm.
            task = pickle.loads(message)
            _process_task(task, memer)
        except Exception:
            logger.exception("Failed to process mule task")
            sentry_sdk.capture_exception()


if __name__ == '__main__':
    main()