Added more stuff to results

This commit is contained in:
Pünkösd Marcell 2021-06-14 03:12:44 +02:00
parent a64b14b06f
commit 07689594b3
3 changed files with 27 additions and 18 deletions

View File

@ -53,7 +53,7 @@ class Classifier(object):
return target_dir, file_name # Az unknown nélkülivel kell visszatérni return target_dir, file_name # Az unknown nélkülivel kell visszatérni
def _run_predictor(self, directory: str) -> list: def _run_predictor(self, directory: str) -> Tuple[str, dict]:
predict_generator = self.datagen.flow_from_directory( predict_generator = self.datagen.flow_from_directory(
directory=directory, directory=directory,
batch_size=128, batch_size=128,
@ -77,11 +77,13 @@ class Classifier(object):
} }
labels = dict((v, k) for k, v in labels.items()) labels = dict((v, k) for k, v in labels.items())
predictions = [labels[k] for k in predicted_class_indices] labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])}
return predictions predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh?
def predict(self, wav_filename: str) -> list: return predicted_class_name, labeled_predictions
def predict(self, wav_filename: str) -> Tuple[str, dict]:
directory, _ = self.create_spectrogram(wav_filename) directory, _ = self.create_spectrogram(wav_filename)
result = self._run_predictor(directory) result = self._run_predictor(directory)

View File

@ -14,6 +14,7 @@ class MagicDoer:
def run_everything(cls, parameters: dict) -> dict: def run_everything(cls, parameters: dict) -> dict:
tag = parameters['tag'] tag = parameters['tag']
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav") sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
response = None
try: try:
# Download Sample # Download Sample
@ -29,7 +30,15 @@ class MagicDoer:
model_details, classifier = cls.classifier_cache.get_default_classifier() model_details, classifier = cls.classifier_cache.get_default_classifier()
# do the majic # do the majic
results = classifier.predict(sample_file_path) predicted_class_name, labeled_predictions = classifier.predict(sample_file_path)
response = {
"tag": tag,
"probability": labeled_predictions[model_details['target_class_name']],
"all_predictions": labeled_predictions,
"class": predicted_class_name,
"model": model_details['id']
}
finally: finally:
try: try:
@ -37,13 +46,10 @@ class MagicDoer:
except FileNotFoundError: except FileNotFoundError:
pass pass
response = { if not response:
"tag": tag, logging.error("Something went wrong during classification!")
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0, else:
"model": model_details['id'] logging.info(f"Classification done!")
} logging.debug(f"Results: {response}")
logging.info(f"Classification done!")
logging.debug(f"Results: {response}")
return response return response

View File

@ -22,11 +22,12 @@ def message_callback(channel, method, properties, body):
with start_transaction(op="cnn-classification-service", name="classify-soundfile"): with start_transaction(op="cnn-classification-service", name="classify-soundfile"):
results = MagicDoer.run_everything(msg) # <- This is where the magic happens results = MagicDoer.run_everything(msg) # <- This is where the magic happens
channel.basic_publish( if results:
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], channel.basic_publish(
routing_key='classification-result', exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
body=json.dumps(results).encode("utf-8") routing_key='classification-result',
) body=json.dumps(results).encode("utf-8")
)
def main(): def main():