Finished stuff
This commit is contained in:
		| @@ -1,3 +1,4 @@ | |||||||
|  | #!/usr/bin/env python3 | ||||||
| from typing import Tuple | from typing import Tuple | ||||||
| import tempfile | import tempfile | ||||||
| import os | import os | ||||||
|   | |||||||
							
								
								
									
										67
									
								
								cnn_classification_service/magic_doer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								cnn_classification_service/magic_doer.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | #!/usr/bin/env python3 | ||||||
|  | import os | ||||||
|  | import logging | ||||||
|  | import tempfile | ||||||
|  | import requests | ||||||
|  |  | ||||||
|  | from cnn_classifier import Classifier | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def run_everything(parameters: dict): | ||||||
|  |     tag = parameters['tag'] | ||||||
|  |  | ||||||
|  |     _, file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") | ||||||
|  |     _, temp_model_name = tempfile.mkstemp(suffix=".json") | ||||||
|  |     _, temp_weights_name = tempfile.mkstemp(suffix=".h5") | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |  | ||||||
|  |         logging.info(f"Downloading sample: {tag}") | ||||||
|  |         r = requests.get(f"http://storage-service/object/{tag}") | ||||||
|  |         with open(file_path, 'wb') as f: | ||||||
|  |             f.write(r.content) | ||||||
|  |  | ||||||
|  |         logging.debug(f"Downloaded sample to {file_path}") | ||||||
|  |  | ||||||
|  |         r = requests.get(f"http://model-service/model/cnn/$default") | ||||||
|  |         r.raise_for_status() | ||||||
|  |  | ||||||
|  |         with open(temp_model_name, 'wb') as f: | ||||||
|  |             f.write(r.content) | ||||||
|  |  | ||||||
|  |         r = requests.get(f"http://model-service/model/cnn/$default?weights") | ||||||
|  |         r.raise_for_status() | ||||||
|  |  | ||||||
|  |         with open(temp_weights_name, 'wb') as f: | ||||||
|  |             f.write(r.content) | ||||||
|  |  | ||||||
|  |         # magic happens here | ||||||
|  |         classifier = Classifier(temp_model_name, temp_weights_name) | ||||||
|  |         results = classifier.predict(file_path) | ||||||
|  |  | ||||||
|  |     finally:  # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh | ||||||
|  |         try: | ||||||
|  |             os.remove(temp_model_name) | ||||||
|  |         except FileNotFoundError: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             os.remove(temp_weights_name) | ||||||
|  |         except FileNotFoundError: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             os.remove(file_path) | ||||||
|  |         except FileNotFoundError: | ||||||
|  |             pass | ||||||
|  |  | ||||||
|  |     response = { | ||||||
|  |         "tag": tag, | ||||||
|  |         "probability": 1.0 if results[0] == 'sturnus' else 0.0, | ||||||
|  |         "model": ... | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     logging.info(f"Classification done!") | ||||||
|  |     logging.debug(f"Results: {response}") | ||||||
|  |  | ||||||
|  |     return response | ||||||
| @@ -8,12 +8,19 @@ import json | |||||||
| from sentry_sdk.integrations.logging import LoggingIntegration | from sentry_sdk.integrations.logging import LoggingIntegration | ||||||
| import sentry_sdk | import sentry_sdk | ||||||
|  |  | ||||||
| from cnn_classifier import Classifier | from magic_doer import run_everything | ||||||
|  |  | ||||||
|  |  | ||||||
| def message_callback(ch, method, properties, body): | def message_callback(ch, method, properties, body): | ||||||
|     msg = json.loads(body.decode('utf-8')) |     msg = json.loads(body.decode('utf-8')) | ||||||
|     # TODO |     results = run_everything(msg) | ||||||
|  |  | ||||||
|  |     # TODO: Ez azért elég gettó, de legalább csatlakozik | ||||||
|  |     connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) | ||||||
|  |     channel = connection.channel() | ||||||
|  |     channel.exchange_declare(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], exchange_type='fanout') | ||||||
|  |     channel.basic_publish(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], routing_key='classification-result', | ||||||
|  |                           body=json.dumps(results).encode("utf-8")) | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
| @@ -37,12 +44,12 @@ def main(): | |||||||
|     logging.info("Connecting to MQ service...") |     logging.info("Connecting to MQ service...") | ||||||
|     connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) |     connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) | ||||||
|     channel = connection.channel() |     channel = connection.channel() | ||||||
|     channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout') |     channel.exchange_declare(exchange=os.environ['PIKA_INPUT_EXCHANGE'], exchange_type='fanout') | ||||||
|  |  | ||||||
|     queue_declare_result = channel.queue_declare(queue='', exclusive=True) |     queue_declare_result = channel.queue_declare(queue='', exclusive=True) | ||||||
|     queue_name = queue_declare_result.method.queue |     queue_name = queue_declare_result.method.queue | ||||||
|  |  | ||||||
|     channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name) |     channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name) | ||||||
|     channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True) |     channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True) | ||||||
|  |  | ||||||
|     logging.info("Connection complete! Listening to messages...") |     logging.info("Connection complete! Listening to messages...") | ||||||
| @@ -54,4 +61,4 @@ def main(): | |||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     main() |     main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user