From 2238fe08f969d529d82d093c403309b8ccab8f90 Mon Sep 17 00:00:00 2001 From: marcsello Date: Thu, 1 Oct 2020 20:15:24 +0200 Subject: [PATCH] Updated local Ai internals --- src/preprocessor/soundpreprocessor.py | 32 ++++++++++++++++++--------- src/utils/config.py | 1 - 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/preprocessor/soundpreprocessor.py b/src/preprocessor/soundpreprocessor.py index 397fe07..6ff4bac 100644 --- a/src/preprocessor/soundpreprocessor.py +++ b/src/preprocessor/soundpreprocessor.py @@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO from pyAudioAnalysis import MidTermFeatures import numpy +from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver + """ Abstract base class for Sender """ @@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor): def __init__(self): logging.info("Downloading current model...") - _, self._temp_model_name = tempfile.mkstemp() + temp_model_handle, self._temp_model_name = tempfile.mkstemp() self._temp_means_name = self._temp_model_name + "MEANS" logging.debug("Fetching model info...") - + BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) if config.SVM_MODEL_ID: model_id_to_get = config.SVM_MODEL_ID else: @@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor): self._model_details = r.json() logging.debug("Downloading model...") + BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}") r.raise_for_status() - with open(self._temp_model_name, 'wb') as f: + with open(temp_model_handle, 'wb') as f: # bruhtastic f.write(r.content) logging.debug("Downloading MEANS...") + BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means") r.raise_for_status() @@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor): self._mid_window, self._mid_step, self._short_window, \ self._short_step, self._compute_beat = load_model(self._temp_model_name) + target_class_name = self._model_details['target_class_name'] + logging.info("The loaded model contains the following classes: " + ", ".join(self._classes)) - if config.SVM_TARGET_CLASS_NAME not in self._classes: - raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes") + if target_class_name not in self._classes: + raise Exception( + f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)") + + self._target_id = self._classes.index(target_class_name) def preprocesssignal(self, file_path: str) -> bool: """ @@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor): logging.debug("Running classification...") - target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError - feature_vector = (mid_features - self._mean) / self._std - class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(), - feature_vector) + class_id, probability = classifier_wrapper( + self._classifier, self._model_details['type'].lower(), feature_vector + ) class_id = int(class_id) # faszom - logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}") + logging.debug( + f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}" + ) - return bool(class_id == target_id) + return bool(class_id == self._target_id) def __del__(self): try: diff --git a/src/utils/config.py b/src/utils/config.py index 67d5328..43d22cd 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883)) MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None) MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None) -SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp") SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID") API_URL = os.environ.get("API_URL", "http://localhost:8080") \ No newline at end of file