Updated local Ai internals
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
a223c8e95e
commit
2238fe08f9
@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO
|
||||
from pyAudioAnalysis import MidTermFeatures
|
||||
import numpy
|
||||
|
||||
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
|
||||
|
||||
"""
|
||||
Abstract base class for Sender
|
||||
"""
|
||||
@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor):
|
||||
|
||||
def __init__(self):
|
||||
logging.info("Downloading current model...")
|
||||
_, self._temp_model_name = tempfile.mkstemp()
|
||||
temp_model_handle, self._temp_model_name = tempfile.mkstemp()
|
||||
self._temp_means_name = self._temp_model_name + "MEANS"
|
||||
|
||||
logging.debug("Fetching model info...")
|
||||
|
||||
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||
if config.SVM_MODEL_ID:
|
||||
model_id_to_get = config.SVM_MODEL_ID
|
||||
else:
|
||||
@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor):
|
||||
self._model_details = r.json()
|
||||
|
||||
logging.debug("Downloading model...")
|
||||
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
|
||||
r.raise_for_status()
|
||||
|
||||
with open(self._temp_model_name, 'wb') as f:
|
||||
with open(temp_model_handle, 'wb') as f: # bruhtastic
|
||||
f.write(r.content)
|
||||
|
||||
logging.debug("Downloading MEANS...")
|
||||
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
|
||||
r.raise_for_status()
|
||||
|
||||
@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor):
|
||||
self._mid_window, self._mid_step, self._short_window, \
|
||||
self._short_step, self._compute_beat = load_model(self._temp_model_name)
|
||||
|
||||
target_class_name = self._model_details['target_class_name']
|
||||
|
||||
logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
|
||||
if config.SVM_TARGET_CLASS_NAME not in self._classes:
|
||||
raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes")
|
||||
if target_class_name not in self._classes:
|
||||
raise Exception(
|
||||
f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)")
|
||||
|
||||
self._target_id = self._classes.index(target_class_name)
|
||||
|
||||
def preprocesssignal(self, file_path: str) -> bool:
|
||||
"""
|
||||
@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor):
|
||||
|
||||
logging.debug("Running classification...")
|
||||
|
||||
target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError
|
||||
|
||||
feature_vector = (mid_features - self._mean) / self._std
|
||||
class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(),
|
||||
feature_vector)
|
||||
class_id, probability = classifier_wrapper(
|
||||
self._classifier, self._model_details['type'].lower(), feature_vector
|
||||
)
|
||||
class_id = int(class_id) # faszom
|
||||
|
||||
logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}")
|
||||
logging.debug(
|
||||
f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
|
||||
)
|
||||
|
||||
return bool(class_id == target_id)
|
||||
return bool(class_id == self._target_id)
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
|
@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883))
|
||||
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
|
||||
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
|
||||
|
||||
SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp")
|
||||
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
|
||||
|
||||
API_URL = os.environ.get("API_URL", "http://localhost:8080")
|
Loading…
Reference in New Issue
Block a user