Updated local Ai internals
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Pünkösd Marcell 2020-10-01 20:15:24 +02:00
parent a223c8e95e
commit 2238fe08f9
2 changed files with 21 additions and 12 deletions

View File

@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import MidTermFeatures
import numpy
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
"""
Abstract base class for Sender
"""
@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor):
def __init__(self):
logging.info("Downloading current model...")
_, self._temp_model_name = tempfile.mkstemp()
temp_model_handle, self._temp_model_name = tempfile.mkstemp()
self._temp_means_name = self._temp_model_name + "MEANS"
logging.debug("Fetching model info...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
if config.SVM_MODEL_ID:
model_id_to_get = config.SVM_MODEL_ID
else:
@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor):
self._model_details = r.json()
logging.debug("Downloading model...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
r.raise_for_status()
with open(self._temp_model_name, 'wb') as f:
with open(temp_model_handle, 'wb') as f: # bruhtastic
f.write(r.content)
logging.debug("Downloading MEANS...")
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
r.raise_for_status()
@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor):
self._mid_window, self._mid_step, self._short_window, \
self._short_step, self._compute_beat = load_model(self._temp_model_name)
target_class_name = self._model_details['target_class_name']
logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
if config.SVM_TARGET_CLASS_NAME not in self._classes:
raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes")
if target_class_name not in self._classes:
raise Exception(
f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)")
self._target_id = self._classes.index(target_class_name)
def preprocesssignal(self, file_path: str) -> bool:
"""
@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor):
logging.debug("Running classification...")
target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError
feature_vector = (mid_features - self._mean) / self._std
class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(),
feature_vector)
class_id, probability = classifier_wrapper(
self._classifier, self._model_details['type'].lower(), feature_vector
)
class_id = int(class_id) # faszom
logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}")
logging.debug(
f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
)
return bool(class_id == target_id)
return bool(class_id == self._target_id)
def __del__(self):
try:

View File

@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883))
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp")
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
API_URL = os.environ.get("API_URL", "http://localhost:8080")