This commit is contained in:
parent
851f451354
commit
fc36a08b70
@ -1,6 +1,6 @@
|
||||
FROM python:3.9
|
||||
|
||||
ADD svm_prefilter_service requirements.txt /svm_prefilter_service/
|
||||
ADD svm_prefilter_service requirements.txt uwsgi.ini /svm_prefilter_service/
|
||||
WORKDIR /svm_prefilter_service/
|
||||
|
||||
ENV PIP_NO_CACHE_DIR=true
|
||||
@ -12,5 +12,5 @@ RUN pip3 install -r requirements.txt
|
||||
ENV GUNICORN_LOGLEVEL="info"
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["gunicorn", "-b", "0.0.0.0:8000", "--log-level", "${GUNICORN_LOGLEVEL}", "app:app"]
|
||||
CMD ["uwsgi", "--ini", "uwsgi.ini"]
|
||||
|
||||
|
@ -4,7 +4,7 @@ blinker
|
||||
Flask~=2.0.1
|
||||
marshmallow~=3.14.1
|
||||
Flask-Classful
|
||||
gunicorn
|
||||
uwsgi
|
||||
sentry_sdk
|
||||
py-healthcheck
|
||||
|
||||
|
@ -7,3 +7,7 @@ class Config:
|
||||
SENTRY_DSN = os.environ.get("SENTRY_DSN")
|
||||
RELEASE_ID = os.environ.get("RELEASE_ID", "test")
|
||||
RELEASEMODE = os.environ.get("RELEASEMODE", "dev")
|
||||
MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/svm/$default")
|
||||
INPUT_SERVICE_URL = os.environ.get("INPUT_SERVICE_URL", "http://input-service/input")
|
||||
|
||||
DROPALL = os.environ.get("DROPALL", "no").lower() in ['yes', 'true', '1']
|
||||
|
136
svm_prefilter_service/mule.py
Normal file
136
svm_prefilter_service/mule.py
Normal file
@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python3
|
||||
import sentry_sdk
|
||||
import os
|
||||
import requests
|
||||
import tempfile
|
||||
import numpy
|
||||
import json
|
||||
import uwsgi
|
||||
|
||||
from config import Config
|
||||
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper
|
||||
from pyAudioAnalysis import audioBasicIO
|
||||
from pyAudioAnalysis import MidTermFeatures
|
||||
import numpy
|
||||
|
||||
if Config.SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=Config.SENTRY_DSN,
|
||||
send_default_pii=True,
|
||||
release=Config.RELEASE_ID,
|
||||
environment=Config.RELEASEMODE
|
||||
)
|
||||
|
||||
|
||||
class ModelMemer:
|
||||
|
||||
def __init__(self):
|
||||
self._loaded_model = None
|
||||
|
||||
def download_model_if_needed(self):
|
||||
models_dir = "/tmp/svm_model"
|
||||
os.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
model_file = os.path.join(models_dir, "model")
|
||||
means_file = os.path.join(models_dir, "modelMEANS")
|
||||
|
||||
if os.path.isfile(model_file) and self._loaded_model:
|
||||
return
|
||||
|
||||
r = requests.get(Config.MODEL_INFO_URL)
|
||||
r.raise_for_status()
|
||||
self.model_details = r.json()
|
||||
|
||||
r = requests.get(self.model_details['files']['model'])
|
||||
r.raise_for_status()
|
||||
|
||||
with open(model_file, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
r = requests.get(self.model_details['files']['means'])
|
||||
r.raise_for_status()
|
||||
|
||||
with open(means_file, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
if self.model_details['type'] == 'knn':
|
||||
self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, self.short_window, \
|
||||
self.short_step, self.compute_beat \
|
||||
= load_model_knn(model_file)
|
||||
|
||||
else:
|
||||
self.classifier, self.mean, self.std, self.classes, self.mid_window, self.mid_step, self.short_window, \
|
||||
self.short_step, self.compute_beat \
|
||||
= load_model(model_file)
|
||||
|
||||
target_class_name = self.model_details['target_class_name']
|
||||
self.target_id = self.classes.index(target_class_name)
|
||||
|
||||
|
||||
def run_classification(audio_file_path: str, memer: ModelMemer):
|
||||
memer.download_model_if_needed()
|
||||
|
||||
# run extraction
|
||||
sampling_rate, signal = audioBasicIO.read_audio_file(audio_file_path)
|
||||
signal = audioBasicIO.stereo_to_mono(signal)
|
||||
|
||||
if sampling_rate == 0:
|
||||
raise AssertionError("Could not read the file properly: Sampling rate zero")
|
||||
|
||||
if signal.shape[0] / float(sampling_rate) <= memer.mid_window:
|
||||
raise AssertionError("Could not read the file properly: Signal shape is not good")
|
||||
|
||||
# feature extraction:
|
||||
mid_features, s, _ = \
|
||||
MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
|
||||
memer.mid_window * sampling_rate,
|
||||
memer.mid_step * sampling_rate,
|
||||
round(sampling_rate * memer.short_window),
|
||||
round(sampling_rate * memer.short_step))
|
||||
|
||||
# long term averaging of mid-term statistics
|
||||
mid_features = mid_features.mean(axis=1)
|
||||
if memer.compute_beat:
|
||||
beat, beat_conf = MidTermFeatures.beat_extraction(s, memer.short_step)
|
||||
mid_features = numpy.append(mid_features, beat)
|
||||
mid_features = numpy.append(mid_features, beat_conf)
|
||||
|
||||
feature_vector = (mid_features - memer.mean) / memer.std
|
||||
class_id, probability = classifier_wrapper(
|
||||
memer.classifier, memer.model_details['type'].lower(), feature_vector
|
||||
)
|
||||
class_id = int(class_id) # faszom
|
||||
|
||||
return bool((class_id == memer.target_id) and (probability[class_id] > 0.5))
|
||||
|
||||
|
||||
def main():
|
||||
memer = ModelMemer()
|
||||
|
||||
while True:
|
||||
message = uwsgi.mule_get_msg()
|
||||
task = json.loads(message)
|
||||
audio_file_path = task['audio_file_path']
|
||||
description = task['description']
|
||||
try:
|
||||
result = run_classification(audio_file_path, memer)
|
||||
if result:
|
||||
# upload to real input service
|
||||
files = {
|
||||
"file": (
|
||||
os.path.basename(audio_file_path),
|
||||
open(audio_file_path, 'rb').read(),
|
||||
'audio/wave',
|
||||
{'Content-length': os.path.getsize(audio_file_path)}
|
||||
),
|
||||
"description": (None, json.dumps(description), "application/json")
|
||||
}
|
||||
|
||||
r = requests.post(Config.INPUT_SERVICE_URL, files=files)
|
||||
r.raise_for_status()
|
||||
finally:
|
||||
os.remove(audio_file_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
from .sample_schema import SampleSchema
|
||||
|
7
svm_prefilter_service/schemas/sample_schema.py
Normal file
7
svm_prefilter_service/schemas/sample_schema.py
Normal file
@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
from marshmallow import fields, Schema
|
||||
|
||||
|
||||
class SampleSchema(Schema):
|
||||
date = fields.DateTime(required=True)
|
||||
device_id = fields.Integer(required=True)
|
@ -1,13 +1,51 @@
|
||||
#!/usr/bin/env python3
|
||||
import tempfile
|
||||
|
||||
from flask import jsonify, request, abort, current_app, Response
|
||||
from flask_classful import FlaskView
|
||||
from utils import json_required
|
||||
import opentracing
|
||||
from schemas import SampleSchema
|
||||
import json
|
||||
import uwsgi
|
||||
|
||||
|
||||
class FilterView(FlaskView):
|
||||
sampleschema = SampleSchema(many=False)
|
||||
|
||||
@json_required
|
||||
def post(self):
|
||||
data = request.json
|
||||
if current_app.config.get('DROPALL'):
|
||||
return Response(status=200)
|
||||
|
||||
return Response(status=201)
|
||||
with opentracing.tracer.start_active_span('parseAndValidate'):
|
||||
if 'file' not in request.files:
|
||||
return abort(400, "no file found")
|
||||
else:
|
||||
soundfile = request.files['file']
|
||||
|
||||
if 'description' not in request.form:
|
||||
return abort(400, "no description found")
|
||||
else:
|
||||
description_raw = request.form.get("description")
|
||||
|
||||
if soundfile.content_type != 'audio/wave':
|
||||
current_app.logger.info(f"Input file was not WAV.")
|
||||
return abort(415, 'Input file not a wave file.')
|
||||
try:
|
||||
desc = self.sampleschema.loads(description_raw)
|
||||
except Exception as e:
|
||||
current_app.logger.exception(e)
|
||||
return abort(417, 'Input JSON schema invalid')
|
||||
|
||||
soundfile_handle, soundfile_path = tempfile.mkstemp()
|
||||
soundfile.save(open(soundfile_handle, "wb+"))
|
||||
|
||||
task = {
|
||||
"audio_file_path": soundfile_path,
|
||||
"description": desc
|
||||
}
|
||||
|
||||
uwsgi.mule_msg(json.dumps(task))
|
||||
|
||||
return Response(status=200)
|
||||
|
Loading…
Reference in New Issue
Block a user