This commit is contained in:
		
							
								
								
									
										55
									
								
								extractor_service/extraction.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								extractor_service/extraction.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/env python3
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					from json import JSONEncoder
 | 
				
			||||||
 | 
					import numpy
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import os.path
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					from pyAudioAnalysis import audioBasicIO
 | 
				
			||||||
 | 
					from pyAudioAnalysis import ShortTermFeatures
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NumpyArrayEncoder(JSONEncoder):
					    """JSON encoder that also understands NumPy arrays and NumPy scalars.

					    ``numpy.ndarray`` values are emitted as (nested) lists via ``tolist()``.
					    NumPy scalar types (``numpy.int64``, ``numpy.float32``, ``numpy.bool_``,
					    ...) are converted to the equivalent built-in Python value via
					    ``item()``.  Every other unsupported object is passed to the base class,
					    which raises ``TypeError``.
					    """

					    def default(self, obj):
					        """Return a JSON-serializable stand-in for ``obj``.

					        Called by the encoder only for objects it cannot serialize natively.
					        """
					        if isinstance(obj, numpy.ndarray):
					            return obj.tolist()
					        # NumPy scalars are not ndarray instances, and most of them are not
					        # subclasses of int/float either, so they need their own branch.
					        if isinstance(obj, numpy.generic):
					            return obj.item()
					        return super().default(obj)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def do_extraction(file_path: str):
					    """Extract short-term audio features from the file at ``file_path``.

					    The audio is loaded with ``audioBasicIO.read_audio_file`` and handed to
					    ``ShortTermFeatures.feature_extraction`` together with two values derived
					    from the sampling rate: ``0.050 * rate`` and ``0.025 * rate``.

					    Returns:
					        A dict with the keys ``"F"`` and ``"f_names"`` holding the two values
					        returned by ``feature_extraction``, in that order.
					    """
					    logging.info("Running extraction...")

					    sampling_rate, signal = audioBasicIO.read_audio_file(file_path)

					    # Both sizes scale with the sampling rate reported by the loader.
					    window = 0.050 * sampling_rate
					    step = 0.025 * sampling_rate
					    features, feature_names = ShortTermFeatures.feature_extraction(
					        signal, sampling_rate, window, step
					    )

					    return {"F": features, "f_names": feature_names}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run_everything(parameters: dict):
					    """Download the sample named by ``parameters['tag']`` and extract features.

					    The sample is fetched from the storage service, written to a temporary
					    ``.wav`` file, passed to ``do_extraction`` and the temporary file is
					    removed again whether or not extraction succeeds.  The result payload is
					    built and logged; pushing it to the AI service is still commented out.

					    Args:
					        parameters: Mapping that must contain a ``'tag'`` entry identifying
					            the sample to process.

					    Raises:
					        KeyError: If ``parameters`` has no ``'tag'`` entry.
					        requests.RequestException: If the download fails, times out or the
					            storage service answers with an error status.
					        OSError: If the temporary file cannot be created or written.
					    """
					    tag = parameters['tag']
					    logging.info(f"Downloading sample: {tag}")

					    download_dir = "/tmp/extractor-service/"
					    # The directory is not guaranteed to exist on a fresh container/host.
					    os.makedirs(download_dir, exist_ok=True)
					    file_path = os.path.join(download_dir, f"{tag}.wav")

					    # Without a timeout a stalled storage service would block forever.
					    r = requests.get(f"http://storage-service/object/{tag}", timeout=30)
					    # Fail here instead of saving an HTTP error page as a .wav file.
					    r.raise_for_status()

					    # download done. Do extraction magic
					    try:
					        # Writing happens inside the try so a partially written file is
					        # cleaned up as well.
					        with open(file_path, 'wb') as f:
					            f.write(r.content)
					        results = do_extraction(file_path)
					    finally:
					        try:
					            os.remove(file_path)
					        except FileNotFoundError:
					            # open() itself failed, so there is nothing to clean up; do not
					            # mask the original exception.
					            pass

					    logging.info("Pushing results to AI service...")

					    response = {
					        "tag": tag,
					        "results": results,
					    }

					    # Lazy %-formatting: the payload is only rendered when DEBUG is enabled.
					    logging.debug("Data being pushed: %s", response)

					    # TODO: pushing to the AI service is not enabled yet.
					    # r = requests.post('http://ai-service/asd', data=json.dumps(results, cls=NumpyArrayEncoder), headers={'Content-Type': 'application/json'})

					    # r.raise_for_status()
 | 
				
			||||||
@@ -3,16 +3,20 @@ import logging
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import pika
 | 
					import pika
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from extraction import run_everything
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def message_callback(ch, method, properties, body):
    """Handle one message delivered by the MQ consumer.

    ``body`` is expected to be a UTF-8 encoded JSON document; once decoded it
    is handed to ``run_everything``.  A body that is not valid UTF-8 or not
    valid JSON is logged and dropped so that a single malformed message does
    not propagate out of the consuming loop.  Errors raised by
    ``run_everything`` itself are not caught here.

    Args:
        ch: Channel the message arrived on (unused).
        method: Delivery metadata (unused).
        properties: Message properties (unused).
        body: Raw message payload as bytes.
    """
    try:
        parameters = json.loads(body.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError):
        logging.exception("Dropping malformed message: %r", body)
        return

    run_everything(parameters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
 | 
					    logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
 | 
				
			||||||
                        level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
 | 
					                        level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info("Connecting to MQ service...")
 | 
				
			||||||
    connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
 | 
					    connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
 | 
				
			||||||
    channel = connection.channel()
 | 
					    channel = connection.channel()
 | 
				
			||||||
    channel.exchange_declare(exchange='wave-extract', exchange_type='fanout')
 | 
					    channel.exchange_declare(exchange='wave-extract', exchange_type='fanout')
 | 
				
			||||||
@@ -23,9 +27,11 @@ def main():
 | 
				
			|||||||
    channel.queue_bind(exchange='wave-extract', queue=queue_name)
 | 
					    channel.queue_bind(exchange='wave-extract', queue=queue_name)
 | 
				
			||||||
    channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
 | 
					    channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logging.info("Connection complete! Listening to messages...")
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        channel.start_consuming()
 | 
					        channel.start_consuming()
 | 
				
			||||||
    except KeyboardInterrupt:
 | 
					    except KeyboardInterrupt:
 | 
				
			||||||
 | 
					        logging.info("SIGINT Received! Stopping stuff...")
 | 
				
			||||||
        channel.stop_consuming()
 | 
					        channel.stop_consuming()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,3 +1,5 @@
 | 
				
			|||||||
sentry_sdk
 | 
					sentry_sdk
 | 
				
			||||||
pika
 | 
					pika
 | 
				
			||||||
requests
 | 
					requests
 | 
				
			||||||
 | 
					pyAudioAnalysis
 | 
				
			||||||
 | 
					numpy
 | 
				
			||||||
		Reference in New Issue
	
	Block a user