Updated local Ai internals
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	continuous-integration/drone/push Build is passing
				
			This commit is contained in:
		@@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO
 | 
			
		||||
from pyAudioAnalysis import MidTermFeatures
 | 
			
		||||
import numpy
 | 
			
		||||
 | 
			
		||||
from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
Abstract base class for Sender
 | 
			
		||||
"""
 | 
			
		||||
@@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor):
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        logging.info("Downloading current model...")
 | 
			
		||||
        _, self._temp_model_name = tempfile.mkstemp()
 | 
			
		||||
        temp_model_handle, self._temp_model_name = tempfile.mkstemp()
 | 
			
		||||
        self._temp_means_name = self._temp_model_name + "MEANS"
 | 
			
		||||
 | 
			
		||||
        logging.debug("Fetching model info...")
 | 
			
		||||
 | 
			
		||||
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
 | 
			
		||||
        if config.SVM_MODEL_ID:
 | 
			
		||||
            model_id_to_get = config.SVM_MODEL_ID
 | 
			
		||||
        else:
 | 
			
		||||
@@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor):
 | 
			
		||||
        self._model_details = r.json()
 | 
			
		||||
 | 
			
		||||
        logging.debug("Downloading model...")
 | 
			
		||||
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
 | 
			
		||||
        r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}")
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        with open(self._temp_model_name, 'wb') as f:
 | 
			
		||||
        with open(temp_model_handle, 'wb') as f:  # bruhtastic
 | 
			
		||||
            f.write(r.content)
 | 
			
		||||
 | 
			
		||||
        logging.debug("Downloading MEANS...")
 | 
			
		||||
        BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1])
 | 
			
		||||
        r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means")
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
@@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor):
 | 
			
		||||
            self._mid_window, self._mid_step, self._short_window, \
 | 
			
		||||
            self._short_step, self._compute_beat = load_model(self._temp_model_name)
 | 
			
		||||
 | 
			
		||||
        target_class_name = self._model_details['target_class_name']
 | 
			
		||||
 | 
			
		||||
        logging.info("The loaded model contains the following classes: " + ", ".join(self._classes))
 | 
			
		||||
        if config.SVM_TARGET_CLASS_NAME not in self._classes:
 | 
			
		||||
            raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes")
 | 
			
		||||
        if target_class_name not in self._classes:
 | 
			
		||||
            raise Exception(
 | 
			
		||||
                f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)")
 | 
			
		||||
 | 
			
		||||
        self._target_id = self._classes.index(target_class_name)
 | 
			
		||||
 | 
			
		||||
    def preprocesssignal(self, file_path: str) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
@@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor):
 | 
			
		||||
 | 
			
		||||
        logging.debug("Running classification...")
 | 
			
		||||
 | 
			
		||||
        target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME)  # Might raise ValueError
 | 
			
		||||
 | 
			
		||||
        feature_vector = (mid_features - self._mean) / self._std
 | 
			
		||||
        class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(),
 | 
			
		||||
                                                   feature_vector)
 | 
			
		||||
        class_id, probability = classifier_wrapper(
 | 
			
		||||
            self._classifier, self._model_details['type'].lower(), feature_vector
 | 
			
		||||
        )
 | 
			
		||||
        class_id = int(class_id)  # faszom
 | 
			
		||||
 | 
			
		||||
        logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}")
 | 
			
		||||
        logging.debug(
 | 
			
		||||
            f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return bool(class_id == target_id)
 | 
			
		||||
        return bool(class_id == self._target_id)
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        try:
 | 
			
		||||
 
 | 
			
		||||
@@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883))
 | 
			
		||||
MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None)
 | 
			
		||||
MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None)
 | 
			
		||||
 | 
			
		||||
SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp")
 | 
			
		||||
SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID")
 | 
			
		||||
 | 
			
		||||
API_URL = os.environ.get("API_URL", "http://localhost:8080")
 | 
			
		||||
		Reference in New Issue
	
	Block a user