#!/usr/bin/env python3
import sentry_sdk
import os
import requests
import json
import uwsgi
import pickle
from threading import Thread
from queue import Queue
from urllib.parse import urljoin
from config import Config
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy
from apscheduler.schedulers.background import BackgroundScheduler
# Enable Sentry error reporting only when a DSN has been configured.
if Config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=Config.SENTRY_DSN,
        send_default_pii=True,
        release=Config.RELEASE_ID,
        environment=Config.RELEASEMODE,
    )
def json_datetime_dumper(o):
    """JSON ``default`` hook: ISO-format datetime-like objects, stringify the rest."""
    return o.isoformat() if hasattr(o, "isoformat") else str(o)
class ModelMemer:
    """Downloads the prefilter model once and keeps it deserialized in memory.

    After a successful `download_model_if_needed()` the instance exposes:
    classifier, mean, std, classes, mid_window, mid_step, short_window,
    short_step, compute_beat, model_details and target_id.
    """

    def __init__(self):
        # True once the model files have been fetched and deserialized.
        self._loaded_model = False

    def download_model_if_needed(self):
        """Fetch and load the classifier model, unless it is already loaded.

        Raises:
            requests.HTTPError: if any of the model downloads fail.
        """
        models_dir = "/tmp/svm_model"
        os.makedirs(models_dir, exist_ok=True)
        model_file = os.path.join(models_dir, "model")
        means_file = os.path.join(models_dir, "modelMEANS")
        # BUG FIX: the loaded flag was never set after loading, so every
        # call re-downloaded and re-deserialized the model. Also verify
        # the means file, not just the model file, is on disk.
        if self._loaded_model and os.path.isfile(model_file) and os.path.isfile(means_file):
            return
        r = requests.get(Config.MODEL_INFO_URL)
        r.raise_for_status()
        self.model_details = r.json()
        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['model']))
        r.raise_for_status()
        with open(model_file, 'wb') as f:
            f.write(r.content)
        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['means']))
        r.raise_for_status()
        with open(means_file, 'wb') as f:
            f.write(r.content)
        # pyAudioAnalysis persists kNN models differently from the others.
        if self.model_details['type'] == 'knn':
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model_knn(model_file)
        else:
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model(model_file)
        target_class_name = self.model_details['target_class_name']
        # Index of the class we filter for; used by run_classification().
        self.target_id = self.classes.index(target_class_name)
        self._loaded_model = True
def run_classification(audio_file_path: str, memer: ModelMemer) -> bool:
    """Classify one audio file with the cached model.

    Returns True when the model's predicted class is the configured target
    class AND its probability exceeds 0.5, otherwise False.

    Raises:
        AssertionError: if the file cannot be read or is shorter than one
            mid-term analysis window.
    """
    memer.download_model_if_needed()
    # run extraction
    sampling_rate, signal = audioBasicIO.read_audio_file(audio_file_path)
    signal = audioBasicIO.stereo_to_mono(signal)
    if sampling_rate == 0:
        raise AssertionError("Could not read the file properly: Sampling rate zero")
    # The clip must be longer than one mid-term window, or feature
    # extraction could not produce even a single frame.
    if signal.shape[0] / float(sampling_rate) <= memer.mid_window:
        raise AssertionError("Could not read the file properly: Signal shape is not good")
    # feature extraction:
    mid_features, s, _ = \
        MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
                                               memer.mid_window * sampling_rate,
                                               memer.mid_step * sampling_rate,
                                               round(sampling_rate * memer.short_window),
                                               round(sampling_rate * memer.short_step))
    # long term averaging of mid-term statistics
    mid_features = mid_features.mean(axis=1)
    # Append beat features only when the model was trained with them, in
    # the same order (beat, then confidence) as at training time.
    if memer.compute_beat:
        beat, beat_conf = MidTermFeatures.beat_extraction(s, memer.short_step)
        mid_features = numpy.append(mid_features, beat)
        mid_features = numpy.append(mid_features, beat_conf)
    # Normalise with the training-set mean/std before classifying.
    feature_vector = (mid_features - memer.mean) / memer.std
    class_id, probability = classifier_wrapper(
        memer.classifier, memer.model_details['type'].lower(), feature_vector
    )
    class_id = int(class_id)  # presumably a numpy scalar from the wrapper — TODO confirm
    return bool((class_id == memer.target_id) and (probability[class_id] > 0.5))
def lapatolas(q: Queue):
    """Forever forward raw uwsgi mule messages into the local queue.

    Runs as a daemon thread; mule_get_msg() blocks until a message arrives.
    """
    while True:
        q.put(uwsgi.mule_get_msg())
def reporter(q: Queue):
    """POST the current backlog length to the configured report endpoint.

    Raises:
        requests.HTTPError: on a 4xx/5xx response from the report service.
    """
    payload = {
        "site": Config.REPORT_ALIAS,
        "measurements": {"queue": q.qsize()},
    }
    print("Reporting queue length of", payload)
    response = requests.post(Config.REPORT_URL, json=payload)
    response.raise_for_status()
    # raise_for_status() only rejects 4xx/5xx; additionally warn on any
    # success code other than the expected 201 Created.
    if response.status_code != 201:
        print(Config.REPORT_URL, "Wrong response:", response.status_code)
def main():
    """Mule entry point: consume pickled tasks, classify, forward matches.

    Spawns a daemon thread that feeds uwsgi mule messages into a local
    queue, optionally schedules periodic queue-length reports, then loops
    forever: classify each audio file and, on a positive match, upload it
    to the input service. The temp file is always removed afterwards.
    """
    memer = ModelMemer()
    requeue = Queue()
    Thread(target=lapatolas, args=(requeue,), daemon=True).start()
    if Config.REPORT_URL:
        scheduler = BackgroundScheduler()
        scheduler.add_job(lambda: reporter(requeue), trigger='interval', seconds=Config.REPORT_INTERVAL)
        scheduler.start()
        # Fire one immediate report so the site shows up before the first
        # interval elapses. NOTE(review): this thread runs reporter() once
        # and exits; guarded by REPORT_URL since reporter posts to it.
        Thread(target=reporter, args=(requeue,), daemon=True).start()
    while True:
        message = requeue.get(block=True)
        # SECURITY NOTE: pickle.loads on received data — acceptable only
        # because messages originate from trusted uwsgi workers in this
        # same deployment; never feed this queue from untrusted sources.
        task = pickle.loads(message)
        audio_file_path = task['audio_file_path']
        description = task['description']
        try:
            result = run_classification(audio_file_path, memer)
            if result:
                # upload to real input service
                # BUG FIX: close the file handle (was a leaked open().read()).
                with open(audio_file_path, 'rb') as audio_f:
                    audio_bytes = audio_f.read()
                files = {
                    "file": (
                        os.path.basename(audio_file_path),
                        audio_bytes,
                        'audio/wave',
                        {'Content-length': os.path.getsize(audio_file_path)}
                    ),
                    "description": (None, json.dumps(description, default=json_datetime_dumper), "application/json")
                }
                r = requests.post(Config.INPUT_SERVICE_URL, files=files)
                r.raise_for_status()
        finally:
            # Always delete the temp file, even when classification or upload fails.
            os.remove(audio_file_path)


if __name__ == '__main__':
    main()