Updated local Ai internals
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Pünkösd Marcell 2020-10-01 20:15:24 +02:00
parent a223c8e95e
commit 2238fe08f9
2 changed files with 21 additions and 12 deletions

View File

@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures from pyAudioAnalysis import MidTermFeatures
import numpy import numpy
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
""" """
Abstract base class for Sender Abstract base class for Sender
""" """
@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor):
def __init__(self): def __init__(self):
logging.info("Downloading current model...") logging.info("Downloading current model...")
_, self._temp_model_name = tempfile.mkstemp() temp_model_handle, self._temp_model_name = tempfile.mkstemp()
self._temp_means_name = self._temp_model_name + "MEANS" self._temp_means_name = self._temp_model_name + "MEANS"
logging.debug("Fetching model info...") logging.debug("Fetching model info...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
if config.SVM_MODEL_ID: if config.SVM_MODEL_ID:
model_id_to_get = config.SVM_MODEL_ID model_id_to_get = config.SVM_MODEL_ID
else: else:
@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor):
self._model_details = r.json() self._model_details = r.json()
logging.debug("Downloading model...") logging.debug("Downloading model...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}") r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
r.raise_for_status() r.raise_for_status()
with open(self._temp_model_name, 'wb') as f: with open(temp_model_handle, 'wb') as f: # bruhtastic
f.write(r.content) f.write(r.content)
logging.debug("Downloading MEANS...") logging.debug("Downloading MEANS...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means") r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
r.raise_for_status() r.raise_for_status()
@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor):
self._mid_window, self._mid_step, self._short_window, \ self._mid_window, self._mid_step, self._short_window, \
self._short_step, self._compute_beat = load_model(self._temp_model_name) self._short_step, self._compute_beat = load_model(self._temp_model_name)
target_class_name = self._model_details['target_class_name']
logging.info("The loaded model contains the following classes: " + ", ".join(self._classes)) logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
if config.SVM_TARGET_CLASS_NAME not in self._classes: if target_class_name not in self._classes:
raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes") raise Exception(
f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)")
self._target_id = self._classes.index(target_class_name)
def preprocesssignal(self, file_path: str) -> bool: def preprocesssignal(self, file_path: str) -> bool:
""" """
@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor):
logging.debug("Running classification...") logging.debug("Running classification...")
target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError
feature_vector = (mid_features - self._mean) / self._std feature_vector = (mid_features - self._mean) / self._std
class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(), class_id, probability = classifier_wrapper(
feature_vector) self._classifier, self._model_details['type'].lower(), feature_vector
)
class_id = int(class_id) # faszom class_id = int(class_id) # faszom
logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}") logging.debug(
f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
)
return bool(class_id == target_id) return bool(class_id == self._target_id)
def __del__(self): def __del__(self):
try: try:

View File

@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883))
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None) MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None) MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp")
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID") SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
API_URL = os.environ.get("API_URL", "http://localhost:8080") API_URL = os.environ.get("API_URL", "http://localhost:8080")