commit 2bba3a4de9
parent 94b5066b16
@@ -5,7 +5,6 @@ import logging
 
 import json
 import tempfile
-from json import JSONEncoder
 
 import requests
 from pyAudioAnalysis import audioBasicIO
@@ -13,20 +12,14 @@ from pyAudioAnalysis import MidTermFeatures
 import numpy
 
 
-class NumpyArrayEncoder(JSONEncoder):
+class NumpyArrayEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, numpy.ndarray):
             return obj.tolist()
-        return JSONEncoder.default(self, obj)
+        return json.JSONEncoder.default(self, obj)
 
 
-def do_extraction(file_path: str):
-    logging.info("Getting default model details...")
-    r = requests.get("http://model-service/model/$default/details")
-    r.raise_for_status()
-
-    model_details = r.json()
-
+def do_extraction(model_details: dict, file_path: str):
     logging.info("Running extraction...")
 
     sampling_rate, signal = audioBasicIO.read_audio_file(file_path)
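
Note: the encoder now subclasses json.JSONEncoder directly, so the standalone JSONEncoder import can be dropped. A minimal standalone sketch of how the class is used (placeholder payload; json.dumps only calls default() for objects it cannot serialize natively):

import json
import numpy

class NumpyArrayEncoder(json.JSONEncoder):
    def default(self, obj):
        # numpy arrays become plain lists; everything else falls through to the base class
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

payload = {"results": numpy.array([0.1, 0.2, 0.3])}
print(json.dumps(payload, cls=NumpyArrayEncoder))  # {"results": [0.1, 0.2, 0.3]}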
@@ -41,10 +34,10 @@ def do_extraction(file_path: str):
     # feature extraction:
     mid_features, s, _ = \
         MidTermFeatures.mid_feature_extraction(signal, sampling_rate,
                                                model_details['mid_window'] * sampling_rate,
                                                model_details['mid_step'] * sampling_rate,
                                                round(sampling_rate * model_details['short_window']),
                                                round(sampling_rate * model_details['short_step']))
 
     # long term averaging of mid-term statistics
     mid_features = mid_features.mean(axis=1)
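
Note: the multiplication by sampling_rate suggests the window/step values in model_details are expressed in seconds and converted to sample counts for mid_feature_extraction. A quick worked example with made-up numbers:

sampling_rate = 22050                  # Hz, illustrative
model_details = {"mid_window": 1.0,    # window/step values in seconds (placeholders)
                 "mid_step": 1.0,
                 "short_window": 0.05,
                 "short_step": 0.05}

mid_window_samples = model_details['mid_window'] * sampling_rate             # 22050.0
short_window_samples = round(sampling_rate * model_details['short_window'])  # 1102
print(mid_window_samples, short_window_samples)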
@@ -53,7 +46,7 @@ def do_extraction(file_path: str):
     mid_features = numpy.append(mid_features, beat)
     mid_features = numpy.append(mid_features, beat_conf)
 
-    #feature_vector = (mid_features - mean) / std # normalization
+    # feature_vector = (mid_features - mean) / std # normalization
 
     return mid_features
 
@@ -67,9 +60,19 @@ def run_everything(parameters: dict):
     with open(file_path, 'wb') as f:
         f.write(r.content)
 
+    logging.debug(f"Downloaded sample to {file_path}")
+
+    logging.info("Getting default model details...")
+    r = requests.get("http://model-service/model/$default/details")
+    r.raise_for_status()
+
+    model_details = r.json()
+
+    logging.debug(f"Using model {model_details['id']}")
+
     # download done. Do extraction magic
     try:
-        results = do_extraction(file_path)
+        results = do_extraction(model_details, file_path)
     finally:
         os.remove(file_path)
 
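
Note: run_everything() now fetches the model details once and hands the parsed dict to do_extraction(). A sketch of the dict shape the downstream code reads (keys taken from the extraction call and logging above; the id, values, and path are placeholders):

model_details = {
    "id": "default",        # used for logging and the response payload
    "mid_window": 1.0,      # seconds
    "mid_step": 1.0,
    "short_window": 0.05,
    "short_step": 0.05,
}
results = do_extraction(model_details, "/tmp/sample.wav")  # illustrative path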
@@ -77,11 +80,17 @@ def run_everything(parameters: dict):
 
     response = {
         "tag": tag,
-        "results": results
+        "results": results,
+        "model": model_details['id']
     }
 
     logging.debug(f"Data being pushed: {str(response)}")
 
-    r = requests.post('http://classification-service/classify', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'})
-    #r.raise_for_status() # An error in a service should not kill other services
+    r = requests.post(
+        'http://classification-service/classify',
+        data=json.dumps(response, cls=NumpyArrayEncoder),
+        headers={'Content-Type': 'application/json'}
+    )
+
+    # r.raise_for_status() # An error in a service should not kill other services
     logging.info(f"Classification service response: {r.status_code}")
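
Note: the POST now sends the full response dict (tag, results, model) rather than the raw results, and the body stays pre-serialized with the custom encoder. That matters because requests' json= keyword offers no cls= hook, so a numpy array in the payload would fail to serialize; a small illustration with placeholder values:

import json
import numpy

class NumpyArrayEncoder(json.JSONEncoder):  # same encoder as defined in this file
    def default(self, obj):
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)

payload = {"tag": "example-tag", "results": numpy.array([0.1, 0.2]), "model": "default"}

try:
    json.dumps(payload)                      # the same encoder path requests' json= would take
except TypeError as err:
    print(f"plain encoder fails: {err}")     # ndarray is not JSON serializable

body = json.dumps(payload, cls=NumpyArrayEncoder)  # succeeds: the array is emitted as a list
print(body)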