diff --git a/cnn_classification_service/cnn_classifier.py b/cnn_classification_service/cnn_classifier.py index 09a66e4..47e9201 100644 --- a/cnn_classification_service/cnn_classifier.py +++ b/cnn_classification_service/cnn_classifier.py @@ -53,7 +53,7 @@ class Classifier(object): return target_dir, file_name # Az unknown nélkülivel kell visszatérni - def _run_predictor(self, directory: str) -> list: + def _run_predictor(self, directory: str) -> Tuple[str, dict]: predict_generator = self.datagen.flow_from_directory( directory=directory, batch_size=128, @@ -77,11 +77,13 @@ class Classifier(object): } labels = dict((v, k) for k, v in labels.items()) - predictions = [labels[k] for k in predicted_class_indices] + labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])} - return predictions + predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh? - def predict(self, wav_filename: str) -> list: + return predicted_class_name, labeled_predictions + + def predict(self, wav_filename: str) -> Tuple[str, dict]: directory, _ = self.create_spectrogram(wav_filename) result = self._run_predictor(directory) diff --git a/cnn_classification_service/magic_doer.py b/cnn_classification_service/magic_doer.py index 6a2337c..de61036 100644 --- a/cnn_classification_service/magic_doer.py +++ b/cnn_classification_service/magic_doer.py @@ -14,6 +14,7 @@ class MagicDoer: def run_everything(cls, parameters: dict) -> dict: tag = parameters['tag'] sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") + response = None try: # Download Sample @@ -29,7 +30,15 @@ class MagicDoer: model_details, classifier = cls.classifier_cache.get_default_classifier() # do the majic - results = classifier.predict(sample_file_path) + predicted_class_name, labeled_predictions = classifier.predict(sample_file_path) + + response = { + "tag": tag, + "probability": labeled_predictions[model_details['target_class_name']], + "all_predictions": labeled_predictions, + "class": predicted_class_name, + "model": model_details['id'] + } finally: try: @@ -37,13 +46,10 @@ class MagicDoer: except FileNotFoundError: pass - response = { - "tag": tag, - "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0, - "model": model_details['id'] - } - - logging.info(f"Classification done!") - logging.debug(f"Results: {response}") + if not response: + logging.error("Something went wrong during classification!") + else: + logging.info(f"Classification done!") + logging.debug(f"Results: {response}") return response diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py index 0789212..d351b0d 100644 --- a/cnn_classification_service/main.py +++ b/cnn_classification_service/main.py @@ -22,11 +22,12 @@ def message_callback(channel, method, properties, body): with start_transaction(op="cnn-classification-service", name="classify-soundfile"): results = MagicDoer.run_everything(msg) # <- This is where the magic happens - channel.basic_publish( - exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], - routing_key='classification-result', - body=json.dumps(results).encode("utf-8") - ) + if results: + channel.basic_publish( + exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], + routing_key='classification-result', + body=json.dumps(results).encode("utf-8") + ) def main():