diff --git a/src/preprocessor/soundpreprocessor.py b/src/preprocessor/soundpreprocessor.py index a3ce688..adcc9c4 100644 --- a/src/preprocessor/soundpreprocessor.py +++ b/src/preprocessor/soundpreprocessor.py @@ -32,7 +32,13 @@ class SoundPreProcessor(AbcPreProcessor): self._temp_means_name = self._temp_model_name + "MEANS" logging.debug("Fetching model info...") - r = requests.get(f"{config.API_URL}/model/svm/$default/details") + + if config.SVM_MODEL_ID: + model_id_to_get = config.SVM_MODEL_ID + else: + model_id_to_get = '$default' + + r = requests.get(f"{config.API_URL}/model/svm/{model_id_to_get}/details") r.raise_for_status() self._model_details = r.json() @@ -97,7 +103,7 @@ class SoundPreProcessor(AbcPreProcessor): logging.info("Running classification...") - target_id = self._classes.index('chirp') # Might raise ValueError + target_id = self._classes.index(config.SVM_TARGET_CLASS_NAME) # Might raise ValueError feature_vector = (mid_features - self._mean) / self._std class_id, probability = classifier_wrapper(self._classifier, self._model_details['type'], feature_vector) diff --git a/src/utils/config.py b/src/utils/config.py index 0e9f44e..67d5328 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -27,7 +27,7 @@ MQTT_PORT = int(os.getenv("GUARD_MQTT_PORT", 1883)) MQTT_USERNAME = os.getenv("GUARD_MQTT_USERNAME", None) MQTT_PASSWORD = os.getenv("GUARD_MQTT_PASSWORD", None) -TARGET_NAME = os.environ.get("TARGET_CLASS_NAME") -MODEL_ID = os.environ.get("MODEL_ID") +SVM_TARGET_CLASS_NAME = os.environ.get("SVM_TARGET_CLASS_NAME", "chirp") +SVM_MODEL_ID = os.environ.get("SVM_MODEL_ID") API_URL = os.environ.get("API_URL", "http://localhost:8080") \ No newline at end of file