Updated local Ai internals
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	continuous-integration/drone/push Build is passing
				
			This commit is contained in:
		| @@ -11,6 +11,8 @@ from pyAudioAnalysis import audioBasicIO | |||||||
| from pyAudioAnalysis import MidTermFeatures | from pyAudioAnalysis import MidTermFeatures | ||||||
| import numpy | import numpy | ||||||
|  |  | ||||||
|  | from birbnetes_iot_platform_raspberry import BirbnetesIoTPlatformStatusDriver | ||||||
|  |  | ||||||
| """ | """ | ||||||
| Abstract base class for Sender | Abstract base class for Sender | ||||||
| """ | """ | ||||||
| @@ -28,11 +30,11 @@ class SoundPreProcessor(AbcPreProcessor): | |||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         logging.info("Downloading current model...") |         logging.info("Downloading current model...") | ||||||
|         _, self._temp_model_name = tempfile.mkstemp() |         temp_model_handle, self._temp_model_name = tempfile.mkstemp() | ||||||
|         self._temp_means_name = self._temp_model_name + "MEANS" |         self._temp_means_name = self._temp_model_name + "MEANS" | ||||||
|  |  | ||||||
|         logging.debug("Fetching model info...") |         logging.debug("Fetching model info...") | ||||||
|  |         BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) | ||||||
|         if config.SVM_MODEL_ID: |         if config.SVM_MODEL_ID: | ||||||
|             model_id_to_get = config.SVM_MODEL_ID |             model_id_to_get = config.SVM_MODEL_ID | ||||||
|         else: |         else: | ||||||
| @@ -44,13 +46,15 @@ class SoundPreProcessor(AbcPreProcessor): | |||||||
|         self._model_details = r.json() |         self._model_details = r.json() | ||||||
|  |  | ||||||
|         logging.debug("Downloading model...") |         logging.debug("Downloading model...") | ||||||
|  |         BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) | ||||||
|         r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}") |         r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}") | ||||||
|         r.raise_for_status() |         r.raise_for_status() | ||||||
|  |  | ||||||
|         with open(self._temp_model_name, 'wb') as f: |         with open(temp_model_handle, 'wb') as f:  # bruhtastic | ||||||
|             f.write(r.content) |             f.write(r.content) | ||||||
|  |  | ||||||
|         logging.debug("Downloading MEANS...") |         logging.debug("Downloading MEANS...") | ||||||
|  |         BirbnetesIoTPlatformStatusDriver.enqueue_pattern('green', [1]) | ||||||
|         r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means") |         r = requests.get(f"{config.API_URL}/model/svm/{self._model_details['id']}?means") | ||||||
|         r.raise_for_status() |         r.raise_for_status() | ||||||
|  |  | ||||||
| @@ -69,9 +73,14 @@ class SoundPreProcessor(AbcPreProcessor): | |||||||
|             self._mid_window, self._mid_step, self._short_window, \ |             self._mid_window, self._mid_step, self._short_window, \ | ||||||
|             self._short_step, self._compute_beat = load_model(self._temp_model_name) |             self._short_step, self._compute_beat = load_model(self._temp_model_name) | ||||||
|  |  | ||||||
|  |         target_class_name = self._model_details['target_class_name'] | ||||||
|  |  | ||||||
|         logging.info("The loaded model contains the following classes: " + ", ".join(self._classes)) |         logging.info("The loaded model contains the following classes: " + ", ".join(self._classes)) | ||||||
|         if config.SVM_TARGET_CLASS_NAME not in self._classes: |         if target_class_name not in self._classes: | ||||||
|             raise Exception(f"The specified target class {config.SVM_TARGET_CLASS_NAME} is not in the possible classes") |             raise Exception( | ||||||
|  |                 f"The specified target class {target_class_name} is not in the possible classes (Wrong model info?)") | ||||||
|  |  | ||||||
|  |         self._target_id = self._classes.index(target_class_name) | ||||||
|  |  | ||||||
|     def preprocesssignal(self, file_path: str) -> bool: |     def preprocesssignal(self, file_path: str) -> bool: | ||||||
|         """ |         """ | ||||||
| @@ -107,16 +116,17 @@ class SoundPreProcessor(AbcPreProcessor): | |||||||
|  |  | ||||||
|         logging.debug("Running classification...") |         logging.debug("Running classification...") | ||||||
|  |  | ||||||
|         target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME)  # Might raise ValueError |  | ||||||
|  |  | ||||||
|         feature_vector = (mid_features - self._mean) / self._std |         feature_vector = (mid_features - self._mean) / self._std | ||||||
|         class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'].lower(), |         class_id, probability = classifier_wrapper( | ||||||
|                                                    feature_vector) |             self._classifier, self._model_details['type'].lower(), feature_vector | ||||||
|  |         ) | ||||||
|         class_id = int(class_id)  # faszom |         class_id = int(class_id)  # faszom | ||||||
|  |  | ||||||
|         logging.debug(f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}") |         logging.debug( | ||||||
|  |             f"Sample {file_path} identified as {self._classes[class_id]} with the probablility of {probability[class_id]}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return bool(class_id == target_id) |         return bool(class_id == self._target_id) | ||||||
|  |  | ||||||
|     def __del__(self): |     def __del__(self): | ||||||
|         try: |         try: | ||||||
|   | |||||||
| @@ -27,7 +27,6 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883)) | |||||||
| MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None) | MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None) | ||||||
| MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None) | MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None) | ||||||
|  |  | ||||||
| SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp") |  | ||||||
| SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID") | SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID") | ||||||
|  |  | ||||||
| API_URL = os.environ.get("API_URL", "http://localhost:8080") | API_URL = os.environ.get("API_URL", "http://localhost:8080") | ||||||
		Reference in New Issue
	
	Block a user