diff --git a/extractor_service/extraction.py b/extractor_service/extraction.py
index c85421b..c1e6a8d 100644
--- a/extractor_service/extraction.py
+++ b/extractor_service/extraction.py
@@ -5,7 +5,6 @@
 import logging
 import json
 import tempfile
-from json import JSONEncoder
 
 import requests
 from pyAudioAnalysis import audioBasicIO
@@ -13,20 +12,14 @@ from pyAudioAnalysis import MidTermFeatures
 import numpy
 
 
-class NumpyArrayEncoder(JSONEncoder):
+class NumpyArrayEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, numpy.ndarray):
             return obj.tolist()
-        return JSONEncoder.default(self, obj)
+        return json.JSONEncoder.default(self, obj)
 
 
-def do_extraction(file_path: str):
-    logging.info("Getting default model details...")
-    r = requests.get("http://model-service/model/$default/details")
-    r.raise_for_status()
-
-    model_details = r.json()
-
+def do_extraction(model_details: dict, file_path: str):
     logging.info("Running extraction...")
 
     sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
@@ -41,10 +34,10 @@ def do_extraction(file_path: str):
     # feature extraction:
     mid_features, s, _ = \
         MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
-                                               model_details['mid_window'] * sampling_rate,
-                                               model_details['mid_step'] * sampling_rate,
-                                               round(sampling_rate * model_details['short_window']),
-                                               round(sampling_rate * model_details['short_step']))
+                                                model_details['mid_window'] * sampling_rate,
+                                                model_details['mid_step'] * sampling_rate,
+                                                round(sampling_rate * model_details['short_window']),
+                                                round(sampling_rate * model_details['short_step']))
 
     # long term averaging of mid-term statistics
     mid_features = mid_features.mean(axis=1)
@@ -53,7 +46,7 @@ def do_extraction(file_path: str):
         mid_features = numpy.append(mid_features, beat)
         mid_features = numpy.append(mid_features, beat_conf)
 
-    #feature_vector = (mid_features - mean) / std  # normalization
+    # feature_vector = (mid_features - mean) / std  # normalization
 
     return mid_features
 
@@ -67,9 +60,19 @@ def run_everything(parameters: dict):
     with open(file_path, 'wb') as f:
         f.write(r.content)
 
+    logging.debug(f"Downloaded sample to {file_path}")
+
+    logging.info("Getting default model details...")
+    r = requests.get("http://model-service/model/$default/details")
+    r.raise_for_status()
+
+    model_details = r.json()
+
+    logging.debug(f"Using model {model_details['id']}")
+
     # download done. Do extraction magic
     try:
-        results = do_extraction(file_path)
+        results = do_extraction(model_details, file_path)
     finally:
         os.remove(file_path)
 
@@ -77,11 +80,17 @@ def run_everything(parameters: dict):
     response = {
         "tag": tag,
-        "results": results
+        "results": results,
+        "model": model_details['id']
     }
 
     logging.debug(f"Data being pushed: {str(response)}")
 
-    r = requests.post('http://classification-service/classify', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'})
-    #r.raise_for_status()  # An error in a service should not kill other services
+    r = requests.post(
+        'http://classification-service/classify',
+        data=json.dumps(response, cls=NumpyArrayEncoder),
+        headers={'Content-Type': 'application/json'}
+    )
+
+    # r.raise_for_status()  # An error in a service should not kill other services
 
     logging.info(f"Classification service response: {r.status_code}")