Added more stuff to results

This commit is contained in:
Pünkösd Marcell 2021-06-14 03:12:44 +02:00
parent a64b14b06f
commit 07689594b3
3 changed files with 27 additions and 18 deletions

View File

@ -53,7 +53,7 @@ class Classifier(object):
return target_dir, file_name # Az unknown nélkülivel kell visszatérni
def _run_predictor(self, directory: str) -> list:
def _run_predictor(self, directory: str) -> Tuple[str, dict]:
predict_generator = self.datagen.flow_from_directory(
directory=directory,
batch_size=128,
@ -77,11 +77,13 @@ class Classifier(object):
}
labels = dict((v, k) for k, v in labels.items())
predictions = [labels[k] for k in predicted_class_indices]
labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])}
return predictions
predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh?
def predict(self, wav_filename: str) -> list:
return predicted_class_name, labeled_predictions
def predict(self, wav_filename: str) -> Tuple[str, dict]:
directory, _ = self.create_spectrogram(wav_filename)
result = self._run_predictor(directory)

View File

@ -14,6 +14,7 @@ class MagicDoer:
def run_everything(cls, parameters: dict) -> dict:
tag = parameters['tag']
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
response = None
try:
# Download Sample
@ -29,7 +30,15 @@ class MagicDoer:
model_details, classifier = cls.classifier_cache.get_default_classifier()
# do the majic
results = classifier.predict(sample_file_path)
predicted_class_name, labeled_predictions = classifier.predict(sample_file_path)
response = {
"tag": tag,
"probability": labeled_predictions[model_details['target_class_name']],
"all_predictions": labeled_predictions,
"class": predicted_class_name,
"model": model_details['id']
}
finally:
try:
@ -37,12 +46,9 @@ class MagicDoer:
except FileNotFoundError:
pass
response = {
"tag": tag,
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
"model": model_details['id']
}
if not response:
logging.error("Something went wrong during classification!")
else:
logging.info(f"Classification done!")
logging.debug(f"Results: {response}")

View File

@ -22,6 +22,7 @@ def message_callback(channel, method, properties, body):
with start_transaction(op="cnn-classification-service", name="classify-soundfile"):
results = MagicDoer.run_everything(msg) # <- This is where the magic happens
if results:
channel.basic_publish(
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
routing_key='classification-result',