diff --git a/cnn_classification_service/cnn_classifier.py b/cnn_classification_service/cnn_classifier.py index fa15c9d..ec61935 100644 --- a/cnn_classification_service/cnn_classifier.py +++ b/cnn_classification_service/cnn_classifier.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 from typing import Tuple import tempfile import os diff --git a/cnn_classification_service/magic_doer.py b/cnn_classification_service/magic_doer.py new file mode 100644 index 0000000..980c509 --- /dev/null +++ b/cnn_classification_service/magic_doer.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +import os +import logging +import tempfile +import requests + +from cnn_classifier import Classifier + + +def run_everything(parameters: dict): + tag = parameters['tag'] + + _, file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") + _, temp_model_name = tempfile.mkstemp(suffix=".json") + _, temp_weights_name = tempfile.mkstemp(suffix=".h5") + + try: + + logging.info(f"Downloading sample: {tag}") + r = requests.get(f"http://storage-service/object/{tag}") + with open(file_path, 'wb') as f: + f.write(r.content) + + logging.debug(f"Downloaded sample to {file_path}") + + r = requests.get(f"http://model-service/model/cnn/$default") + r.raise_for_status() + + with open(temp_model_name, 'wb') as f: + f.write(r.content) + + r = requests.get(f"http://model-service/model/cnn/$default?weights") + r.raise_for_status() + + with open(temp_weights_name, 'wb') as f: + f.write(r.content) + + # magic happens here + classifier = Classifier(temp_model_name, temp_weights_name) + results = classifier.predict(file_path) + + finally: # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh + try: + os.remove(temp_model_name) + except FileNotFoundError: + pass + + try: + os.remove(temp_weights_name) + except FileNotFoundError: + pass + + try: + os.remove(file_path) + except FileNotFoundError: + pass + + response = { + "tag": tag, + "probability": 1.0 if results[0] == 'sturnus' else 0.0, + "model": ... + } + + logging.info(f"Classification done!") + logging.debug(f"Results: {response}") + + return response diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py index 4789c6d..9dadc51 100644 --- a/cnn_classification_service/main.py +++ b/cnn_classification_service/main.py @@ -8,12 +8,19 @@ import json from sentry_sdk.integrations.logging import LoggingIntegration import sentry_sdk -from cnn_classifier import Classifier +from magic_doer import run_everything def message_callback(ch, method, properties, body): msg = json.loads(body.decode('utf-8')) - # TODO + results = run_everything(msg) + + # TODO: Ez azért elég gettó, de legalább csatlakozik + connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) + channel = connection.channel() + channel.exchange_declare(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], exchange_type='fanout') + channel.basic_publish(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], routing_key='classification-result', + body=json.dumps(results).encode("utf-8")) def main(): @@ -37,12 +44,12 @@ def main(): logging.info("Connecting to MQ service...") connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) channel = connection.channel() - channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout') + channel.exchange_declare(exchange=os.environ['PIKA_INPUT_EXCHANGE'], exchange_type='fanout') queue_declare_result = channel.queue_declare(queue='', exclusive=True) queue_name = queue_declare_result.method.queue - channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name) + channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name) channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True) logging.info("Connection complete! Listening to messages...") @@ -54,4 +61,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main()