Finished stuff
This commit is contained in:
parent
a45abf5870
commit
fe0552fb09
@ -1,3 +1,4 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import tempfile
|
import tempfile
|
||||||
import os
|
import os
|
||||||
|
67
cnn_classification_service/magic_doer.py
Normal file
67
cnn_classification_service/magic_doer.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from cnn_classifier import Classifier
|
||||||
|
|
||||||
|
|
||||||
|
def run_everything(parameters: dict):
|
||||||
|
tag = parameters['tag']
|
||||||
|
|
||||||
|
_, file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||||
|
_, temp_model_name = tempfile.mkstemp(suffix=".json")
|
||||||
|
_, temp_weights_name = tempfile.mkstemp(suffix=".h5")
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
logging.info(f"Downloading sample: {tag}")
|
||||||
|
r = requests.get(f"http://storage-service/object/{tag}")
|
||||||
|
with open(file_path, 'wb') as f:
|
||||||
|
f.write(r.content)
|
||||||
|
|
||||||
|
logging.debug(f"Downloaded sample to {file_path}")
|
||||||
|
|
||||||
|
r = requests.get(f"http://model-service/model/cnn/$default")
|
||||||
|
r.raise_for_status()
|
||||||
|
|
||||||
|
with open(temp_model_name, 'wb') as f:
|
||||||
|
f.write(r.content)
|
||||||
|
|
||||||
|
r = requests.get(f"http://model-service/model/cnn/$default?weights")
|
||||||
|
r.raise_for_status()
|
||||||
|
|
||||||
|
with open(temp_weights_name, 'wb') as f:
|
||||||
|
f.write(r.content)
|
||||||
|
|
||||||
|
# magic happens here
|
||||||
|
classifier = Classifier(temp_model_name, temp_weights_name)
|
||||||
|
results = classifier.predict(file_path)
|
||||||
|
|
||||||
|
finally: # bruuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuuh
|
||||||
|
try:
|
||||||
|
os.remove(temp_model_name)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.remove(temp_weights_name)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"tag": tag,
|
||||||
|
"probability": 1.0 if results[0] == 'sturnus' else 0.0,
|
||||||
|
"model": ...
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.info(f"Classification done!")
|
||||||
|
logging.debug(f"Results: {response}")
|
||||||
|
|
||||||
|
return response
|
@ -8,12 +8,19 @@ import json
|
|||||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
|
||||||
from cnn_classifier import Classifier
|
from magic_doer import run_everything
|
||||||
|
|
||||||
|
|
||||||
def message_callback(ch, method, properties, body):
|
def message_callback(ch, method, properties, body):
|
||||||
msg = json.loads(body.decode('utf-8'))
|
msg = json.loads(body.decode('utf-8'))
|
||||||
# TODO
|
results = run_everything(msg)
|
||||||
|
|
||||||
|
# TODO: Ez azért elég gettó, de legalább csatlakozik
|
||||||
|
connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
|
||||||
|
channel = connection.channel()
|
||||||
|
channel.exchange_declare(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], exchange_type='fanout')
|
||||||
|
channel.basic_publish(exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], routing_key='classification-result',
|
||||||
|
body=json.dumps(results).encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -37,12 +44,12 @@ def main():
|
|||||||
logging.info("Connecting to MQ service...")
|
logging.info("Connecting to MQ service...")
|
||||||
connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
|
connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
|
||||||
channel = connection.channel()
|
channel = connection.channel()
|
||||||
channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout')
|
channel.exchange_declare(exchange=os.environ['PIKA_INPUT_EXCHANGE'], exchange_type='fanout')
|
||||||
|
|
||||||
queue_declare_result = channel.queue_declare(queue='', exclusive=True)
|
queue_declare_result = channel.queue_declare(queue='', exclusive=True)
|
||||||
queue_name = queue_declare_result.method.queue
|
queue_name = queue_declare_result.method.queue
|
||||||
|
|
||||||
channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name)
|
channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], queue=queue_name)
|
||||||
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
|
||||||
|
|
||||||
logging.info("Connection complete! Listening to messages...")
|
logging.info("Connection complete! Listening to messages...")
|
||||||
@ -54,4 +61,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user