svm-prefilter-service/svm_prefilter_service/mule.py

#!/usr/bin/env python3
import sentry_sdk
import os
import requests
import json
import uwsgi
import pickle
from threading import Thread
from queue import Queue
from urllib.parse import urljoin
from config import Config
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy
from apscheduler.schedulers.background import BackgroundScheduler
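
# Error reporting: Sentry is only initialised when a DSN is configured, so the
# mule can also run locally without any Sentry account.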
if Config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=Config.SENTRY_DSN,
        send_default_pii=True,
        release=Config.RELEASE_ID,
        environment=Config.RELEASEMODE
    )
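
# json.dumps() cannot serialise datetime objects on its own; this helper is passed
# as default= further below so that datetimes in the task description are emitted
# as ISO 8601 strings and anything else unknown falls back to str().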
def json_datetime_dumper(o):
    if hasattr(o, "isoformat"):
        return o.isoformat()
    else:
        return str(o)
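
# ModelMemer lazily downloads the classifier published at MODEL_INFO_URL and keeps
# it in memory for the lifetime of the mule process. Judging from how the fields
# are used below, the info document is expected to be JSON with at least 'type',
# 'target_class_name' and a 'files' mapping containing 'model' and 'means' entries.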
class ModelMemer:
    def __init__(self):
        self._loaded_model = False

    def download_model_if_needed(self):
        models_dir = "/tmp/svm_model"
        os.makedirs(models_dir, exist_ok=True)
        model_file = os.path.join(models_dir, "model")
        means_file = os.path.join(models_dir, "modelMEANS")
        if os.path.isfile(model_file) and self._loaded_model:
            return
        print("Model file needs to be downloaded. Doing it...", flush=True)
        r = requests.get(Config.MODEL_INFO_URL)
        r.raise_for_status()
        self.model_details = r.json()
        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['model']))
        r.raise_for_status()
        with open(model_file, 'wb') as f:
            f.write(r.content)
        r = requests.get(urljoin(Config.MODEL_INFO_URL, self.model_details['files']['means']))
        r.raise_for_status()
        with open(means_file, 'wb') as f:
            f.write(r.content)
        if self.model_details['type'] == 'knn':
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model_knn(model_file)
        else:
            self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, \
                self.short_window, self.short_step, self.compute_beat = load_model(model_file)
        target_class_name = self.model_details['target_class_name']
        self.target_id = self.classes.index(target_class_name)
        self._loaded_model = True
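
# Classification essentially mirrors pyAudioAnalysis' audioTrainTest.file_classification():
# read the file, downmix to mono, extract mid-term features, average them over the whole
# clip, normalise with the stored mean/std and feed the vector to the classifier. A sample
# "passes" the prefilter only if the predicted class is the configured target class and
# its probability exceeds 0.5.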
def run_classification(audio_file_path: str, memer: ModelMemer):
    memer.download_model_if_needed()
    # read the sample and downmix to mono
    sampling_rate, signal = audioBasicIO.read_audio_file(audio_file_path)
    signal = audioBasicIO.stereo_to_mono(signal)
    if sampling_rate == 0:
        raise AssertionError("Could not read the file properly: sampling rate is zero")
    if signal.shape[0] / float(sampling_rate) <= memer.mid_window:
        raise AssertionError("Could not read the file properly: signal is shorter than one mid-term window")
    # feature extraction
    mid_features, s, _ = \
        MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
                                               memer.mid_window * sampling_rate,
                                               memer.mid_step * sampling_rate,
                                               round(sampling_rate * memer.short_window),
                                               round(sampling_rate * memer.short_step))
    # long term averaging of mid-term statistics
    mid_features = mid_features.mean(axis=1)
    if memer.compute_beat:
        beat, beat_conf = MidTermFeatures.beat_extraction(s, memer.short_step)
        mid_features = numpy.append(mid_features, beat)
        mid_features = numpy.append(mid_features, beat_conf)
    feature_vector = (mid_features - memer.mean) / memer.std
    class_id, probability = classifier_wrapper(
        memer.classifier, memer.model_details['type'].lower(), feature_vector
    )
    class_id = int(class_id)  # classifier_wrapper returns a numpy scalar; coerce to a plain int for indexing
    return bool((class_id == memer.target_id) and (probability[class_id] > 0.5))
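
# lapatolas ("shovelling" in Hungarian) is the bridge between uwsgi and Python land:
# it blocks on uwsgi.mule_get_msg() and shovels every incoming mule message into a
# plain Queue so the main loop (and the queue-length reporter) can work with it.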
def lapatolas(q: Queue):
    while True:
        message = uwsgi.mule_get_msg()
        q.put(message)
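
# reporter posts the current backlog (queue length) to REPORT_URL so an external
# monitor can see whether the mule keeps up with the incoming samples.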
def reporter(q: Queue):
    report = {
        "site": Config.REPORT_ALIAS,
        "measurements": {
            "queue": q.qsize()
        }
    }
    print("Reporting queue length:", report, flush=True)
    r = requests.post(Config.REPORT_URL, json=report)
    r.raise_for_status()
    if r.status_code != 201:
        print(Config.REPORT_URL, "Unexpected response status:", r.status_code, flush=True)
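
# The main loop consumes tasks produced elsewhere in the service. The producer side is
# assumed to enqueue pickled dicts roughly like this (illustrative sketch only; the real
# sender is expected to live in the web worker, not in this file):
#
#     uwsgi.mule_msg(pickle.dumps({
#         "audio_file_path": "/tmp/recording.wav",          # hypothetical path
#         "description": {"recorded_at": datetime.utcnow()},  # hypothetical metadata
#     }))
#
# Each accepted sample is re-uploaded to INPUT_SERVICE_URL; the temporary file is always
# removed afterwards.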
def main():
    memer = ModelMemer()
    requeue = Queue()
    Thread(target=lapatolas, args=(requeue,), daemon=True).start()
    scheduler = None
    if Config.REPORT_URL:
        scheduler = BackgroundScheduler()
        scheduler.add_job(lambda: reporter(requeue), trigger='interval', seconds=Config.REPORT_INTERVAL)
        scheduler.start()
        # send one report right away instead of waiting for the first interval
        Thread(target=reporter, args=(requeue,), daemon=True).start()
    while True:
        message = requeue.get(block=True)
        print('New sample received... start classifying!', flush=True)
        task = pickle.loads(message)
        audio_file_path = task['audio_file_path']
        description = task['description']
        try:
            result = run_classification(audio_file_path, memer)
            print("Result of classification:", result, flush=True)
            if result:
                # upload the accepted sample to the real input service
                with open(audio_file_path, 'rb') as audio_file:
                    audio_bytes = audio_file.read()
                files = {
                    "file": (
                        os.path.basename(audio_file_path),
                        audio_bytes,
                        'audio/wave',
                        {'Content-length': os.path.getsize(audio_file_path)}
                    ),
                    "description": (None, json.dumps(description, default=json_datetime_dumper), "application/json")
                }
                r = requests.post(Config.INPUT_SERVICE_URL, files=files)
                r.raise_for_status()
                print("Upload response status:", r.status_code, flush=True)
                print("Upload response body:", r.content, flush=True)
        finally:
            # the temporary file is removed whether or not classification/upload succeeded
            os.remove(audio_file_path)
    # if scheduler:
    #     scheduler.stop()


if __name__ == '__main__':
    main()