From 09c991baf3c04624d5f92c24d7b30afd682fa21b Mon Sep 17 00:00:00 2001 From: marcsello Date: Fri, 25 Sep 2020 13:48:11 +0200 Subject: [PATCH] Implemented AI stuff --- src/preprocessor/abcpreprocessor.py | 2 +- src/preprocessor/soundpreprocessor.py | 108 +++++++++++++++++++++----- 2 files changed, 90 insertions(+), 20 deletions(-) diff --git a/src/preprocessor/abcpreprocessor.py b/src/preprocessor/abcpreprocessor.py index 6c9f327..28633ac 100644 --- a/src/preprocessor/abcpreprocessor.py +++ b/src/preprocessor/abcpreprocessor.py @@ -16,7 +16,7 @@ class AbcPreProcessor(ABC): Abstract base class PreProcessor. Responsible for manipulating input data from a sensor. """ @abstractmethod - def preprocesssignal(self, signal): + def preprocesssignal(self, file_path) -> bool: """ Preprocess a signal. :return: diff --git a/src/preprocessor/soundpreprocessor.py b/src/preprocessor/soundpreprocessor.py index 5f828aa..d362561 100644 --- a/src/preprocessor/soundpreprocessor.py +++ b/src/preprocessor/soundpreprocessor.py @@ -1,9 +1,15 @@ #!/usr/bin/env python3 import requests -from pyAudioAnalysis import audioTrainTest as aT -from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn +from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn, classifier_wrapper from utils import config from .abcpreprocessor import AbcPreProcessor +import tempfile +import os +import logging + +from pyAudioAnalysis import audioBasicIO +from pyAudioAnalysis import MidTermFeatures +import numpy """ Abstract base class for Sender @@ -20,27 +26,91 @@ class SoundPreProcessor(AbcPreProcessor): SoundPreProcessor class, responsible for detecting birb chirps in sound sample. """ - def preprocesssignal(self, signal: str) -> bool: - """ - Classify a sound sample. - :param signal: Access path of the sound sample up for processing. - :return: - """ - # TODO: Dynamic model injection? - r = requests.get(f"http://model-service/model/{config.MODEL_ID}/details") + def __init__(self): + logging.info("Downloading current model...") + _, self._temp_model_name = tempfile.mkstemp() + self._temp_means_name = self._temp_model_name + "MEANS" + + logging.debug("Fetching model info...") + r = requests.get(f"{config.API_URL}/model/$default/details") r.raise_for_status() - model_details = r.json() + self._model_details = r.json() - if model_details['type'] == 'knn': - classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \ - = load_model_knn(config.MODEL_ID + "MEANS") + logging.debug("Downloading model...") + r = requests.get(f"{config.API_URL}/model/{self._model_details['id']}") + r.raise_for_status() + + with open(self._temp_model_name, 'wb') as f: + f.write(r.content) + + logging.debug("Downloading MEANS...") + r = requests.get(f"{config.API_URL}/model/{self._model_details['id']}?means") + r.raise_for_status() + + with open(self._temp_means_name, 'wb') as f: + f.write(r.content) + + logging.info("Loading current model...") + + if self._model_details['type'] == 'knn': + self._classifier, self._mean, self._std, self._classes, \ + self._mid_window, self._mid_step, self._short_window, \ + self._short_step, self._compute_beat = load_model_knn(self._temp_model_name) else: - classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \ - = load_model(config.MODEL_ID + "MEANS") + self._classifier, self._mean, self._std, self._classes, \ + self._mid_window, self._mid_step, self._short_window, \ + self._short_step, self._compute_beat = load_model(self._temp_model_name) - target_id = classes.index(config.TARGET_NAME) + def preprocesssignal(self, file_path: str) -> bool: + """ + Classify a sound sample. + :param file_path: Access path of the sound sample up for processing. + :return: + """ + logging.info("Running extraction...") - class_id, probability = aT.file_classification(signal, config.MODEL_ID, "svm") - return class_id == target_id + sampling_rate, signal = audioBasicIO.read_audio_file(file_path) + signal = audioBasicIO.stereo_to_mono(signal) + + if sampling_rate == 0: + raise Exception("Could not read the file properly: Sampling rate zero") + + if signal.shape[0] / float(sampling_rate) <= self._mid_window: + raise Exception("Could not read the file properly: Signal shape is not good") + + # feature extraction: + mid_features, s, _ = \ + MidTermFeatures.mid_feature_extraction(signal, sampling_rate, + self._mid_window * sampling_rate, + self._mid_step * sampling_rate, + round(sampling_rate * self._short_window), + round(sampling_rate * self._short_step)) + + # long term averaging of mid-term statistics + mid_features = mid_features.mean(axis=1) + if self._compute_beat: + beat, beat_conf = MidTermFeatures.beat_extraction(s, self._short_step) + mid_features = numpy.append(mid_features, beat) + mid_features = numpy.append(mid_features, beat_conf) + + logging.info("Running classification...") + + target_id = self._classes.index('chirp') # Might raise ValueError + + feature_vector = (mid_features - self._mean) / self._std + class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'], feature_vector) + + return bool(class_id == target_id) + + def __del__(self): + try: + os.remove(self._temp_model_name) + except FileNotFoundError: + pass + + try: + os.remove(self._temp_means_name) + except FileNotFoundError: + pass