From fcef542a7c8d3eb5967bafdddf00a8f7db3f0707 Mon Sep 17 00:00:00 2001 From: marcsello Date: Mon, 27 Jul 2020 17:58:48 +0200 Subject: [PATCH] initial commit --- .gitignore | 133 ++++++++++++++++++++ cnn_classification_service/cnn_clasifier.py | 87 +++++++++++++ cnn_classification_service/main.py | 57 +++++++++ requirements.txt | 10 ++ 4 files changed, 287 insertions(+) create mode 100644 .gitignore create mode 100644 cnn_classification_service/cnn_clasifier.py create mode 100644 cnn_classification_service/main.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cf3c132 --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +#Pycharm +.idea/ +*.iml diff --git a/cnn_classification_service/cnn_clasifier.py b/cnn_classification_service/cnn_clasifier.py new file mode 100644 index 0000000..fa15c9d --- /dev/null +++ b/cnn_classification_service/cnn_clasifier.py @@ -0,0 +1,87 @@ +from typing import Tuple +import tempfile +import os +import os.path +import shutil + +import librosa +import librosa.display +import numpy +import matplotlib.pyplot +from keras.models import model_from_json +from keras import optimizers +from keras_preprocessing.image import ImageDataGenerator + + +class Classifier(object): + + def __init__(self, model_filename: str, weights_filename: str): + with open(model_filename, 'r') as f: + self.loaded_model = model_from_json(f.read()) + + self.loaded_model.load_weights(weights_filename) + self.datagen = ImageDataGenerator(rescale=1. / 255., validation_split=0.25) + self.loaded_model.compile(optimizers.rmsprop(lr=0.0005, decay=1e-6), loss="categorical_crossentropy", + metrics=["accuracy"]) + self.loaded_model.summary() + + @staticmethod + def create_spectrogram(wav_filename: str) -> Tuple[str, str]: + matplotlib.pyplot.interactive(False) + clip, sample_rate = librosa.load(wav_filename, sr=None) + fig = matplotlib.pyplot.figure(figsize=[0.72, 0.72]) + ax = fig.add_subplot(111) + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + ax.set_frame_on(False) + spectogram = librosa.feature.melspectrogram(y=clip, sr=sample_rate) + librosa.display.specshow(librosa.power_to_db(spectogram, ref=numpy.max)) + + target_dir = tempfile.mkdtemp() + + # Change extension to jpg... mert 110% biztos vagyok benne hogy a keras nem bírná beolvasni máshogy + file_name = os.path.join(target_dir, "unknown", f"{wav_filename[:-4]}.jpg") + + matplotlib.pyplot.savefig(file_name, dpi=400, bbox_inches='tight', pad_inches=0) + matplotlib.pyplot.close() + fig.clf() + matplotlib.pyplot.close(fig) + matplotlib.pyplot.close('all') + + return target_dir, file_name + + def _run_predictor(self, directory: str) -> list: + predict_generator = self.datagen.flow_from_directory( + directory=directory, + batch_size=128, + seed=42, + shuffle=False, + class_mode="categorical", + target_size=(64, 64)) + + prediction = self.loaded_model.predict_generator(predict_generator, steps=1) + + predicted_class_indices = numpy.argmax(prediction, axis=1) + + labels = { + 'anser': 0, + 'columba': 1, + 'hirundo': 2, + 'passer': 3, + 'sturnus': 4, + 'turdus': 5, + 'upupa': 6 + } + labels = dict((v, k) for k, v in labels.items()) + + predictions = [labels[k] for k in predicted_class_indices] + + return predictions + + def predict(self, wav_filename: str) -> list: + directory, _ = self.create_spectrogram(wav_filename) + + result = self._run_predictor(directory) + shutil.rmtree(directory) # The image is no longer needed + + return result diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py new file mode 100644 index 0000000..4789c6d --- /dev/null +++ b/cnn_classification_service/main.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import logging +import os +import sys +import pika +import json + +from sentry_sdk.integrations.logging import LoggingIntegration +import sentry_sdk + +from cnn_classifier import Classifier + + +def message_callback(ch, method, properties, body): + msg = json.loads(body.decode('utf-8')) + # TODO + + +def main(): + logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s", + level=logging.DEBUG if '--debug' in sys.argv else logging.INFO) + + SENTRY_DSN = os.environ.get("SENTRY_DSN") + if SENTRY_DSN: + sentry_logging = LoggingIntegration( + level=logging.DEBUG, # Capture info and above as breadcrumbs + event_level=logging.ERROR # Send errors as events + ) + sentry_sdk.init( + dsn=SENTRY_DSN, + integrations=[sentry_logging], + send_default_pii=True, + release=os.environ.get('RELEASE_ID', 'test'), + environment=os.environ.get('RELEASEMODE', 'dev') + ) + + logging.info("Connecting to MQ service...") + connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) + channel = connection.channel() + channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout') + + queue_declare_result = channel.queue_declare(queue='', exclusive=True) + queue_name = queue_declare_result.method.queue + + channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name) + channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True) + + logging.info("Connection complete! Listening to messages...") + try: + channel.start_consuming() + except KeyboardInterrupt: + logging.info("SIGINT Received! Stopping stuff...") + channel.stop_consuming() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f544a55 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +sentry_sdk +pika +requests + + +librosa +keras +numpy +matplotlib +keras_preprocessing \ No newline at end of file