initial commit
This commit is contained in:
		
							
								
								
									
										133
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,133 @@
 | 
			
		||||
# Byte-compiled / optimized / DLL files
 | 
			
		||||
__pycache__/
 | 
			
		||||
*.py[cod]
 | 
			
		||||
*$py.class
 | 
			
		||||
 | 
			
		||||
# C extensions
 | 
			
		||||
*.so
 | 
			
		||||
 | 
			
		||||
# Distribution / packaging
 | 
			
		||||
.Python
 | 
			
		||||
build/
 | 
			
		||||
develop-eggs/
 | 
			
		||||
dist/
 | 
			
		||||
downloads/
 | 
			
		||||
eggs/
 | 
			
		||||
.eggs/
 | 
			
		||||
lib/
 | 
			
		||||
lib64/
 | 
			
		||||
parts/
 | 
			
		||||
sdist/
 | 
			
		||||
var/
 | 
			
		||||
wheels/
 | 
			
		||||
pip-wheel-metadata/
 | 
			
		||||
share/python-wheels/
 | 
			
		||||
*.egg-info/
 | 
			
		||||
.installed.cfg
 | 
			
		||||
*.egg
 | 
			
		||||
MANIFEST
 | 
			
		||||
 | 
			
		||||
# PyInstaller
 | 
			
		||||
#  Usually these files are written by a python script from a template
 | 
			
		||||
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
 | 
			
		||||
*.manifest
 | 
			
		||||
*.spec
 | 
			
		||||
 | 
			
		||||
# Installer logs
 | 
			
		||||
pip-log.txt
 | 
			
		||||
pip-delete-this-directory.txt
 | 
			
		||||
 | 
			
		||||
# Unit test / coverage reports
 | 
			
		||||
htmlcov/
 | 
			
		||||
.tox/
 | 
			
		||||
.nox/
 | 
			
		||||
.coverage
 | 
			
		||||
.coverage.*
 | 
			
		||||
.cache
 | 
			
		||||
nosetests.xml
 | 
			
		||||
coverage.xml
 | 
			
		||||
*.cover
 | 
			
		||||
*.py,cover
 | 
			
		||||
.hypothesis/
 | 
			
		||||
.pytest_cache/
 | 
			
		||||
 | 
			
		||||
# Translations
 | 
			
		||||
*.mo
 | 
			
		||||
*.pot
 | 
			
		||||
 | 
			
		||||
# Django stuff:
 | 
			
		||||
*.log
 | 
			
		||||
local_settings.py
 | 
			
		||||
db.sqlite3
 | 
			
		||||
db.sqlite3-journal
 | 
			
		||||
 | 
			
		||||
# Flask stuff:
 | 
			
		||||
instance/
 | 
			
		||||
.webassets-cache
 | 
			
		||||
 | 
			
		||||
# Scrapy stuff:
 | 
			
		||||
.scrapy
 | 
			
		||||
 | 
			
		||||
# Sphinx documentation
 | 
			
		||||
docs/_build/
 | 
			
		||||
 | 
			
		||||
# PyBuilder
 | 
			
		||||
target/
 | 
			
		||||
 | 
			
		||||
# Jupyter Notebook
 | 
			
		||||
.ipynb_checkpoints
 | 
			
		||||
 | 
			
		||||
# IPython
 | 
			
		||||
profile_default/
 | 
			
		||||
ipython_config.py
 | 
			
		||||
 | 
			
		||||
# pyenv
 | 
			
		||||
.python-version
 | 
			
		||||
 | 
			
		||||
# pipenv
 | 
			
		||||
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 | 
			
		||||
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
 | 
			
		||||
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
 | 
			
		||||
#   install all needed dependencies.
 | 
			
		||||
#Pipfile.lock
 | 
			
		||||
 | 
			
		||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
 | 
			
		||||
__pypackages__/
 | 
			
		||||
 | 
			
		||||
# Celery stuff
 | 
			
		||||
celerybeat-schedule
 | 
			
		||||
celerybeat.pid
 | 
			
		||||
 | 
			
		||||
# SageMath parsed files
 | 
			
		||||
*.sage.py
 | 
			
		||||
 | 
			
		||||
# Environments
 | 
			
		||||
.env
 | 
			
		||||
.venv
 | 
			
		||||
env/
 | 
			
		||||
venv/
 | 
			
		||||
ENV/
 | 
			
		||||
env.bak/
 | 
			
		||||
venv.bak/
 | 
			
		||||
 | 
			
		||||
# Spyder project settings
 | 
			
		||||
.spyderproject
 | 
			
		||||
.spyproject
 | 
			
		||||
 | 
			
		||||
# Rope project settings
 | 
			
		||||
.ropeproject
 | 
			
		||||
 | 
			
		||||
# mkdocs documentation
 | 
			
		||||
/site
 | 
			
		||||
 | 
			
		||||
# mypy
 | 
			
		||||
.mypy_cache/
 | 
			
		||||
.dmypy.json
 | 
			
		||||
dmypy.json
 | 
			
		||||
 | 
			
		||||
# Pyre type checker
 | 
			
		||||
.pyre/
 | 
			
		||||
 | 
			
		||||
#Pycharm
 | 
			
		||||
.idea/
 | 
			
		||||
*.iml
 | 
			
		||||
							
								
								
									
										87
									
								
								cnn_classification_service/cnn_clasifier.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								cnn_classification_service/cnn_clasifier.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,87 @@
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
import tempfile
 | 
			
		||||
import os
 | 
			
		||||
import os.path
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
import librosa
 | 
			
		||||
import librosa.display
 | 
			
		||||
import numpy
 | 
			
		||||
import matplotlib.pyplot
 | 
			
		||||
from keras.models import model_from_json
 | 
			
		||||
from keras import optimizers
 | 
			
		||||
from keras_preprocessing.image import ImageDataGenerator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Classifier(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, model_filename: str, weights_filename: str):
 | 
			
		||||
        with open(model_filename, 'r') as f:
 | 
			
		||||
            self.loaded_model = model_from_json(f.read())
 | 
			
		||||
 | 
			
		||||
        self.loaded_model.load_weights(weights_filename)
 | 
			
		||||
        self.datagen = ImageDataGenerator(rescale=1. / 255., validation_split=0.25)
 | 
			
		||||
        self.loaded_model.compile(optimizers.rmsprop(lr=0.0005, decay=1e-6), loss="categorical_crossentropy",
 | 
			
		||||
                                  metrics=["accuracy"])
 | 
			
		||||
        self.loaded_model.summary()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_spectrogram(wav_filename: str) -> Tuple[str, str]:
 | 
			
		||||
        matplotlib.pyplot.interactive(False)
 | 
			
		||||
        clip, sample_rate = librosa.load(wav_filename, sr=None)
 | 
			
		||||
        fig = matplotlib.pyplot.figure(figsize=[0.72, 0.72])
 | 
			
		||||
        ax = fig.add_subplot(111)
 | 
			
		||||
        ax.axes.get_xaxis().set_visible(False)
 | 
			
		||||
        ax.axes.get_yaxis().set_visible(False)
 | 
			
		||||
        ax.set_frame_on(False)
 | 
			
		||||
        spectogram = librosa.feature.melspectrogram(y=clip, sr=sample_rate)
 | 
			
		||||
        librosa.display.specshow(librosa.power_to_db(spectogram, ref=numpy.max))
 | 
			
		||||
 | 
			
		||||
        target_dir = tempfile.mkdtemp()
 | 
			
		||||
 | 
			
		||||
        # Change extension to jpg... mert 110% biztos vagyok benne hogy a keras nem bírná beolvasni máshogy
 | 
			
		||||
        file_name = os.path.join(target_dir, "unknown", f"{wav_filename[:-4]}.jpg")
 | 
			
		||||
 | 
			
		||||
        matplotlib.pyplot.savefig(file_name, dpi=400, bbox_inches='tight', pad_inches=0)
 | 
			
		||||
        matplotlib.pyplot.close()
 | 
			
		||||
        fig.clf()
 | 
			
		||||
        matplotlib.pyplot.close(fig)
 | 
			
		||||
        matplotlib.pyplot.close('all')
 | 
			
		||||
 | 
			
		||||
        return target_dir, file_name
 | 
			
		||||
 | 
			
		||||
    def _run_predictor(self, directory: str) -> list:
 | 
			
		||||
        predict_generator = self.datagen.flow_from_directory(
 | 
			
		||||
            directory=directory,
 | 
			
		||||
            batch_size=128,
 | 
			
		||||
            seed=42,
 | 
			
		||||
            shuffle=False,
 | 
			
		||||
            class_mode="categorical",
 | 
			
		||||
            target_size=(64, 64))
 | 
			
		||||
 | 
			
		||||
        prediction = self.loaded_model.predict_generator(predict_generator, steps=1)
 | 
			
		||||
 | 
			
		||||
        predicted_class_indices = numpy.argmax(prediction, axis=1)
 | 
			
		||||
 | 
			
		||||
        labels = {
 | 
			
		||||
            'anser': 0,
 | 
			
		||||
            'columba': 1,
 | 
			
		||||
            'hirundo': 2,
 | 
			
		||||
            'passer': 3,
 | 
			
		||||
            'sturnus': 4,
 | 
			
		||||
            'turdus': 5,
 | 
			
		||||
            'upupa': 6
 | 
			
		||||
        }
 | 
			
		||||
        labels = dict((v, k) for k, v in labels.items())
 | 
			
		||||
 | 
			
		||||
        predictions = [labels[k] for k in predicted_class_indices]
 | 
			
		||||
 | 
			
		||||
        return predictions
 | 
			
		||||
 | 
			
		||||
    def predict(self, wav_filename: str) -> list:
 | 
			
		||||
        directory, _ = self.create_spectrogram(wav_filename)
 | 
			
		||||
 | 
			
		||||
        result = self._run_predictor(directory)
 | 
			
		||||
        shutil.rmtree(directory)  # The image is no longer needed
 | 
			
		||||
 | 
			
		||||
        return result
 | 
			
		||||
							
								
								
									
										57
									
								
								cnn_classification_service/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								cnn_classification_service/main.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,57 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import pika
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from sentry_sdk.integrations.logging import LoggingIntegration
 | 
			
		||||
import sentry_sdk
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_callback(ch, method, properties, body):
 | 
			
		||||
    msg = json.loads(body.decode('utf-8'))
 | 
			
		||||
    # TODO
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    logging.basicConfig(filename="", format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
 | 
			
		||||
                        level=logging.DEBUG if '--debug' in sys.argv else logging.INFO)
 | 
			
		||||
 | 
			
		||||
    SENTRY_DSN = os.environ.get("SENTRY_DSN")
 | 
			
		||||
    if SENTRY_DSN:
 | 
			
		||||
        sentry_logging = LoggingIntegration(
 | 
			
		||||
            level=logging.DEBUG,  # Capture info and above as breadcrumbs
 | 
			
		||||
            event_level=logging.ERROR  # Send errors as events
 | 
			
		||||
        )
 | 
			
		||||
        sentry_sdk.init(
 | 
			
		||||
            dsn=SENTRY_DSN,
 | 
			
		||||
            integrations=[sentry_logging],
 | 
			
		||||
            send_default_pii=True,
 | 
			
		||||
            release=os.environ.get('RELEASE_ID', 'test'),
 | 
			
		||||
            environment=os.environ.get('RELEASEMODE', 'dev')
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    logging.info("Connecting to MQ service...")
 | 
			
		||||
    connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL']))
 | 
			
		||||
    channel = connection.channel()
 | 
			
		||||
    channel.exchange_declare(exchange=os.environ['PIKA_EXCHANGE_NAME'], exchange_type='fanout')
 | 
			
		||||
 | 
			
		||||
    queue_declare_result = channel.queue_declare(queue='', exclusive=True)
 | 
			
		||||
    queue_name = queue_declare_result.method.queue
 | 
			
		||||
 | 
			
		||||
    channel.queue_bind(exchange=os.environ['PIKA_EXCHANGE_NAME'], queue=queue_name)
 | 
			
		||||
    channel.basic_consume(queue=queue_name, on_message_callback=message_callback, auto_ack=True)
 | 
			
		||||
 | 
			
		||||
    logging.info("Connection complete! Listening to messages...")
 | 
			
		||||
    try:
 | 
			
		||||
        channel.start_consuming()
 | 
			
		||||
    except KeyboardInterrupt:
 | 
			
		||||
        logging.info("SIGINT Received! Stopping stuff...")
 | 
			
		||||
        channel.stop_consuming()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										10
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
sentry_sdk
 | 
			
		||||
pika
 | 
			
		||||
requests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
librosa
 | 
			
		||||
keras
 | 
			
		||||
numpy
 | 
			
		||||
matplotlib
 | 
			
		||||
keras_preprocessing
 | 
			
		||||
		Reference in New Issue
	
	Block a user