diff --git a/extractor_service/extraction.py b/extractor_service/extraction.py
new file mode 100644
index 0000000..3f7e363
--- /dev/null
+++ b/extractor_service/extraction.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+"""Download an audio sample and extract its short-term features."""
+import json
+from json import JSONEncoder
+import numpy
+import os
+import os.path
+import logging
+import requests
+from pyAudioAnalysis import audioBasicIO
+from pyAudioAnalysis import ShortTermFeatures
+
+# Scratch directory that holds a sample while it is being processed.
+DOWNLOAD_DIR = "/tmp/extractor-service/"
+# Seconds to wait on the storage service before giving up.
+REQUEST_TIMEOUT = 30
+
+
+class NumpyArrayEncoder(JSONEncoder):
+    """JSON encoder that serializes numpy arrays as (nested) lists."""
+
+    def default(self, obj):
+        if isinstance(obj, numpy.ndarray):
+            return obj.tolist()
+        return JSONEncoder.default(self, obj)
+
+
+def do_extraction(file_path: str):
+    """Extract short-term audio features from the file at *file_path*.
+
+    Returns a dict holding the feature_extraction() outputs: the features
+    under "F" and their names under "f_names".
+    """
+    logging.info("Running extraction...")
+
+    [Fs, x] = audioBasicIO.read_audio_file(file_path)
+    # Presumably Fs is the sampling rate, making these a 50 ms window and 25 ms step.
+    F, f_names = ShortTermFeatures.feature_extraction(x, Fs, 0.050 * Fs, 0.025 * Fs)
+
+    return {"F": F, "f_names": f_names}
+
+
+def run_everything(parameters: dict):
+    """Download the sample named by parameters['tag'] and extract its features.
+
+    The downloaded file is removed again whether or not extraction succeeds.
+
+    Raises:
+        KeyError: if parameters has no 'tag' entry.
+        requests.HTTPError: if the storage service returns an error status.
+    """
+    tag = parameters['tag']
+    logging.info(f"Downloading sample: {tag}")
+
+    # The scratch directory may not exist yet; open() would fail without it.
+    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
+    file_path = os.path.join(DOWNLOAD_DIR, f"{tag}.wav")
+
+    try:
+        r = requests.get(f"http://storage-service/object/{tag}", timeout=REQUEST_TIMEOUT)
+        # Fail fast rather than saving an error body as if it were audio.
+        r.raise_for_status()
+        with open(file_path, 'wb') as f:
+            f.write(r.content)
+
+        # download done. Do extraction magic
+        results = do_extraction(file_path)
+    finally:
+        # The file is absent when the download itself failed.
+        if os.path.exists(file_path):
+            os.remove(file_path)
+
+    logging.info("Pushing results to AI service...")
+
+    response = {
+        "tag": tag,
+        "results": results
+    }
+
+    logging.debug(f"Data being pushed: {str(response)}")
+
+    # TODO: pushing to the AI service is still disabled.
+    # r = requests.post('http://ai-service/asd', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'})
+
+    # r.raise_for_status()
diff --git a/extractor_service/main.py b/extractor_service/main.py
index 9d8161a..f1fde2d 100644
--- a/extractor_service/main.py
+++ b/extractor_service/main.py
@@ -3,16 +3,20 @@ import logging
 import os
 import sys
 import pika
+import json
+
+from extraction import run_everything
 
 
 def message_callback(ch, method, properties, body):
-    pass
+    run_everything(json.loads(body.decode('utf-8')))
 
 
 def main():
     logging.basicConfig(filename="",
                         format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
                         level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
+    logging.info("Connecting to MQ service...")
     connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
     channel = connection.channel()
     channel.exchange_declare(exchange='wave-extract', exchange_type='fanout')
@@ -23,9 +27,11 @@ def main():
     channel.queue_bind(exchange='wave-extract', queue=queue_name)
 
     channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
+    logging.info("Connection complete! Listening to messages...")
     try:
         channel.start_consuming()
     except KeyboardInterrupt:
+        logging.info("SIGINT Received! Stopping stuff...")
         channel.stop_consuming()
 
 
diff --git a/requirements.txt b/requirements.txt
index 4c1fdcf..baee911 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 sentry_sdk
 pika
-requests
\ No newline at end of file
+requests
+pyAudioAnalysis
+numpy
\ No newline at end of file