Updated local Ai internals
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
a223c8e95e
commit
2238fe08f9
@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO
|
|||||||
from pyAudioAnalysis import MidTermFeatures
|
from pyAudioAnalysis import MidTermFeatures
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Abstract base class for Sender
|
Abstract base class for Sender
|
||||||
"""
|
"""
|
||||||
@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
logging.info("Downloading current model...")
|
logging.info("Downloading current model...")
|
||||||
_, self._temp_model_name = tempfile.mkstemp()
|
temp_model_handle, self._temp_model_name = tempfile.mkstemp()
|
||||||
self._temp_means_name = self._temp_model_name + "MEANS"
|
self._temp_means_name = self._temp_model_name + "MEANS"
|
||||||
|
|
||||||
logging.debug("Fetching model info...")
|
logging.debug("Fetching model info...")
|
||||||
|
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||||
if config.SVM_MODEL_ID:
|
if config.SVM_MODEL_ID:
|
||||||
model_id_to_get = config.SVM_MODEL_ID
|
model_id_to_get = config.SVM_MODEL_ID
|
||||||
else:
|
else:
|
||||||
@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor):
|
|||||||
self._model_details = r.json()
|
self._model_details = r.json()
|
||||||
|
|
||||||
logging.debug("Downloading model...")
|
logging.debug("Downloading model...")
|
||||||
|
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||||
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
|
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
with open(self._temp_model_name, 'wb') as f:
|
with open(temp_model_handle, 'wb') as f: # bruhtastic
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
|
|
||||||
logging.debug("Downloading MEANS...")
|
logging.debug("Downloading MEANS...")
|
||||||
|
BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
|
||||||
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
|
r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor):
|
|||||||
self._mid_window, self._mid_step, self._short_window, \
|
self._mid_window, self._mid_step, self._short_window, \
|
||||||
self._short_step, self._compute_beat = load_model(self._temp_model_name)
|
self._short_step, self._compute_beat = load_model(self._temp_model_name)
|
||||||
|
|
||||||
|
target_class_name = self._model_details['target_class_name']
|
||||||
|
|
||||||
logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
|
logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
|
||||||
if config.SVM_TARGET_CLASS_NAME not in self._classes:
|
if target_class_name not in self._classes:
|
||||||
raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes")
|
raise Exception(
|
||||||
|
f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)")
|
||||||
|
|
||||||
|
self._target_id = self._classes.index(target_class_name)
|
||||||
|
|
||||||
def preprocesssignal(self, file_path: str) -> bool:
|
def preprocesssignal(self, file_path: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor):
|
|||||||
|
|
||||||
logging.debug("Running classification...")
|
logging.debug("Running classification...")
|
||||||
|
|
||||||
target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError
|
|
||||||
|
|
||||||
feature_vector = (mid_features - self._mean) / self._std
|
feature_vector = (mid_features - self._mean) / self._std
|
||||||
class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(),
|
class_id, probability = classifier_wrapper(
|
||||||
feature_vector)
|
self._classifier, self._model_details['type'].lower(), feature_vector
|
||||||
|
)
|
||||||
class_id = int(class_id) # faszom
|
class_id = int(class_id) # faszom
|
||||||
|
|
||||||
logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}")
|
logging.debug(
|
||||||
|
f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
|
||||||
|
)
|
||||||
|
|
||||||
return bool(class_id == target_id)
|
return bool(class_id == self._target_id)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
try:
|
try:
|
||||||
|
@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883))
|
|||||||
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
|
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
|
||||||
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
|
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
|
||||||
|
|
||||||
SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp")
|
|
||||||
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
|
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
|
||||||
|
|
||||||
API_URL = os.environ.get("API_URL", "http://localhost:8080")
|
API_URL = os.environ.get("API_URL", "http://localhost:8080")
|
Loading…
Reference in New Issue
Block a user