Finished stuff
This commit is contained in:
		@@ -1,3 +1,4 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
import tempfile
 | 
			
		||||
import os
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										67
									
								
								cnn_classification_service/magic_doer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								cnn_classification_service/magic_doer.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
import os
 | 
			
		||||
import logging
 | 
			
		||||
import tempfile
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_everything(parameters: dict):
 | 
			
		||||
    tag = parameters['tag']
 | 
			
		||||
 | 
			
		||||
    _, file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
 | 
			
		||||
    _, temp_model_name = tempfile.mkstemp(suffix=".json")
 | 
			
		||||
    _, temp_weights_name = tempfile.mkstemp(suffix=".h5")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
 | 
			
		||||
        logging.info(f"Downloading sample: {tag}")
 | 
			
		||||
        r = requests.get(f"http://storage-service/object/{tag}")
 | 
			
		||||
        with open(file_path, 'wb') as f:
 | 
			
		||||
            f.write(r.content)
 | 
			
		||||
 | 
			
		||||
        logging.debug(f"Downloaded sample to {file_path}")
 | 
			
		||||
 | 
			
		||||
        r = requests.get(f"http://model-service/model/cnn/$default")
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        with open(temp_model_name, 'wb') as f:
 | 
			
		||||
            f.write(r.content)
 | 
			
		||||
 | 
			
		||||
        r = requests.get(f"http://model-service/model/cnn/$default?weights")
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        with open(temp_weights_name, 'wb') as f:
 | 
			
		||||
            f.write(r.content)
 | 
			
		||||
 | 
			
		||||
        # magic happens here
 | 
			
		||||
        classifier = Classifier(temp_model_name, temp_weights_name)
 | 
			
		||||
        results = classifier.predict(file_path)
 | 
			
		||||
 | 
			
		||||
    finally:  # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh
 | 
			
		||||
        try:
 | 
			
		||||
            os.remove(temp_model_name)
 | 
			
		||||
        except FileNotFoundError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            os.remove(temp_weights_name)
 | 
			
		||||
        except FileNotFoundError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            os.remove(file_path)
 | 
			
		||||
        except FileNotFoundError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    response = {
 | 
			
		||||
        "tag": tag,
 | 
			
		||||
        "probability": 1.0 if results[0] == 'sturnus' else 0.0,
 | 
			
		||||
        "model": ...
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    logging.info(f"Classification done!")
 | 
			
		||||
    logging.debug(f"Results: {response}")
 | 
			
		||||
 | 
			
		||||
    return response
 | 
			
		||||
@@ -8,12 +8,19 @@ import json
 | 
			
		||||
from sentry_sdk.integrations.logging import LoggingIntegration
 | 
			
		||||
import sentry_sdk
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
from magic_doer import run_everything
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_callback(ch, method, properties, body):
 | 
			
		||||
    msg = json.loads(body.decode('utf-8'))
 | 
			
		||||
    # TODO
 | 
			
		||||
    results = run_everything(msg)
 | 
			
		||||
 | 
			
		||||
    # TODO: Ez azért elég gettó, de legalább csatlakozik
 | 
			
		||||
    connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
 | 
			
		||||
    channel = connection.channel()
 | 
			
		||||
    channel.exchange_declare(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], exchange_type='fanout')
 | 
			
		||||
    channel.basic_publish(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], routing_key='classification-result',
 | 
			
		||||
                          body=json.dumps(results).encode("utf-8"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
@@ -37,12 +44,12 @@ def main():
 | 
			
		||||
    logging.info("Connecting to MQ service...")
 | 
			
		||||
    connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
 | 
			
		||||
    channel = connection.channel()
 | 
			
		||||
    channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout')
 | 
			
		||||
    channel.exchange_declare(exchange=os.environ['PIKA_INPUT_EXCHANGE'], exchange_type='fanout')
 | 
			
		||||
 | 
			
		||||
    queue_declare_result = channel.queue_declare(queue='', exclusive=True)
 | 
			
		||||
    queue_name = queue_declare_result.method.queue
 | 
			
		||||
 | 
			
		||||
    channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name)
 | 
			
		||||
    channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name)
 | 
			
		||||
    channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
 | 
			
		||||
 | 
			
		||||
    logging.info("Connection complete! Listening to messages...")
 | 
			
		||||
@@ -54,4 +61,4 @@ def main():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
    main()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user