This commit is contained in:
parent
e11e6ffc3b
commit
ef67ad3ba1
55
extractor_service/extraction.py
Normal file
55
extractor_service/extraction.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import json
|
||||||
|
from json import JSONEncoder
|
||||||
|
import numpy
|
||||||
|
import os
|
||||||
|
import os.path
|
||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from pyAudioAnalysis import audioBasicIO
|
||||||
|
from pyAudioAnalysis import ShortTermFeatures
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyArrayEncoder(JSONEncoder):
|
||||||
|
def default(self, obj):
|
||||||
|
if isinstance(obj, numpy.ndarray):
|
||||||
|
return obj.tolist()
|
||||||
|
return JSONEncoder.default(self, obj)
|
||||||
|
|
||||||
|
|
||||||
|
def do_extraction(file_path: str):
|
||||||
|
logging.info("Running extraction...")
|
||||||
|
|
||||||
|
[Fs, x] = audioBasicIO.read_audio_file(file_path)
|
||||||
|
F, f_names = ShortTermFeatures.feature_extraction(x, Fs, 0.050 * Fs, 0.025 * Fs)
|
||||||
|
|
||||||
|
return {"F": F, "f_names": f_names}
|
||||||
|
|
||||||
|
|
||||||
|
def run_everything(parameters: dict):
|
||||||
|
tag = parameters['tag']
|
||||||
|
logging.info(f"Downloading sample: {tag}")
|
||||||
|
|
||||||
|
file_path = os.path.join("/tmp/extractor-service/", f"{tag}.wav")
|
||||||
|
r = requests.get(f"http://storage-service/object/{tag}")
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
f.write(r.content)
|
||||||
|
|
||||||
|
# download done. Do extraction magic
|
||||||
|
try:
|
||||||
|
results = do_extraction(file_path)
|
||||||
|
finally:
|
||||||
|
os.remove(file_path)
|
||||||
|
|
||||||
|
logging.info(f"Pushing results to AI service...")
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"tag": tag,
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.debug(f"Data being pushed: {str(response)}")
|
||||||
|
|
||||||
|
# r = requests.post('http://ai-service/asd', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'})
|
||||||
|
|
||||||
|
# r.raise_for_status()
|
@ -3,16 +3,20 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import pika
|
import pika
|
||||||
|
import json
|
||||||
|
|
||||||
|
from extraction import run_everything
|
||||||
|
|
||||||
|
|
||||||
def message_callback(ch, method, properties, body):
|
def message_callback(ch, method, properties, body):
|
||||||
pass
|
run_everything(json.loads(body.decode('utf-8')))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
|
logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
|
||||||
level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
|
level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
|
||||||
|
|
||||||
|
logging.info("Connecting to MQ service...")
|
||||||
connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
|
connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
|
||||||
channel = connection.channel()
|
channel = connection.channel()
|
||||||
channel.exchange_declare(exchange='wave-extract', exchange_type='fanout')
|
channel.exchange_declare(exchange='wave-extract', exchange_type='fanout')
|
||||||
@ -23,9 +27,11 @@ def main():
|
|||||||
channel.queue_bind(exchange='wave-extract', queue=queue_name)
|
channel.queue_bind(exchange='wave-extract', queue=queue_name)
|
||||||
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
||||||
|
|
||||||
|
logging.info("Connection complete! Listening to messages...")
|
||||||
try:
|
try:
|
||||||
channel.start_consuming()
|
channel.start_consuming()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
logging.info("SIGINT Received! Stopping stuff...")
|
||||||
channel.stop_consuming()
|
channel.stop_consuming()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
sentry_sdk
|
sentry_sdk
|
||||||
pika
|
pika
|
||||||
requests
|
requests
|
||||||
|
pyAudioAnalysis
|
||||||
|
numpy
|
Reference in New Issue
Block a user