4
0

Finished main stuff
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Pünkösd Marcell 2020-04-19 21:17:32 +02:00
parent 94b5066b16
commit 2bba3a4de9

View File

@ -5,7 +5,6 @@ import logging
import json import json
import tempfile import tempfile
from json import JSONEncoder
import requests import requests
from pyAudioAnalysis import audioBasicIO from pyAudioAnalysis import audioBasicIO
@ -13,20 +12,14 @@ from pyAudioAnalysis import MidTermFeatures
import numpy import numpy
class NumpyArrayEncoder(JSONEncoder): class NumpyArrayEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if isinstance(obj, numpy.ndarray): if isinstance(obj, numpy.ndarray):
return obj.tolist() return obj.tolist()
return JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
def do_extraction(file_path: str): def do_extraction(model_details: dict, file_path: str):
logging.info("Getting default model details...")
r = requests.get("http://model-service/model/$default/details")
r.raise_for_status()
model_details = r.json()
logging.info("Running extraction...") logging.info("Running extraction...")
sampling_rate, signal = audioBasicIO.read_audio_file(file_path) sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
@ -41,10 +34,10 @@ def do_extraction(file_path: str):
# feature extraction: # feature extraction:
mid_features, s, _ = \ mid_features, s, _ = \
MidTermFeatures.mid_feature_extraction(signal, sampling_rate, MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
model_details['mid_window'] * sampling_rate, model_details['mid_window'] * sampling_rate,
model_details['mid_step'] * sampling_rate, model_details['mid_step'] * sampling_rate,
round(sampling_rate * model_details['short_window']), round(sampling_rate * model_details['short_window']),
round(sampling_rate * model_details['short_step'])) round(sampling_rate * model_details['short_step']))
# long term averaging of mid-term statistics # long term averaging of mid-term statistics
mid_features = mid_features.mean(axis=1) mid_features = mid_features.mean(axis=1)
@ -53,7 +46,7 @@ def do_extraction(file_path: str):
mid_features = numpy.append(mid_features, beat) mid_features = numpy.append(mid_features, beat)
mid_features = numpy.append(mid_features, beat_conf) mid_features = numpy.append(mid_features, beat_conf)
#feature_vector = (mid_features - mean) / std # normalization # feature_vector = (mid_features - mean) / std # normalization
return mid_features return mid_features
@ -67,9 +60,19 @@ def run_everything(parameters: dict):
with open(file_path, 'wb') as f: with open(file_path, 'wb') as f:
f.write(r.content) f.write(r.content)
logging.debug(f"Downloaded sample to {file_path}")
logging.info("Getting default model details...")
r = requests.get("http://model-service/model/$default/details")
r.raise_for_status()
model_details = r.json()
logging.debug(f"Using model {model_details['id']}")
# download done. Do extraction magic # download done. Do extraction magic
try: try:
results = do_extraction(file_path) results = do_extraction(model_details, file_path)
finally: finally:
os.remove(file_path) os.remove(file_path)
@ -77,11 +80,17 @@ def run_everything(parameters: dict):
response = { response = {
"tag": tag, "tag": tag,
"results": results "results": results,
"model": model_details['id']
} }
logging.debug(f"Data being pushed: {str(response)}") logging.debug(f"Data being pushed: {str(response)}")
r = requests.post('http://classification-service/classify', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'}) r = requests.post(
#r.raise_for_status() # An error in a service should not kill other services 'http://classification-service/classify',
data=json.dumps(response, cls=NumpyArrayEncoder),
headers={'Content-Type': 'application/json'}
)
# r.raise_for_status() # An error in a service should not kill other services
logging.info(f"Classification service response: {r.status_code}") logging.info(f"Classification service response: {r.status_code}")