Added more stuff to results
This commit is contained in:
parent
a64b14b06f
commit
07689594b3
@ -53,7 +53,7 @@ class Classifier(object):
|
||||
|
||||
return target_dir, file_name # Az unknown nélkülivel kell visszatérni
|
||||
|
||||
def _run_predictor(self, directory: str) -> list:
|
||||
def _run_predictor(self, directory: str) -> Tuple[str, dict]:
|
||||
predict_generator = self.datagen.flow_from_directory(
|
||||
directory=directory,
|
||||
batch_size=128,
|
||||
@ -77,11 +77,13 @@ class Classifier(object):
|
||||
}
|
||||
labels = dict((v, k) for k, v in labels.items())
|
||||
|
||||
predictions = [labels[k] for k in predicted_class_indices]
|
||||
labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])}
|
||||
|
||||
return predictions
|
||||
predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh?
|
||||
|
||||
def predict(self, wav_filename: str) -> list:
|
||||
return predicted_class_name, labeled_predictions
|
||||
|
||||
def predict(self, wav_filename: str) -> Tuple[str, dict]:
|
||||
directory, _ = self.create_spectrogram(wav_filename)
|
||||
|
||||
result = self._run_predictor(directory)
|
||||
|
@ -14,6 +14,7 @@ class MagicDoer:
|
||||
def run_everything(cls, parameters: dict) -> dict:
|
||||
tag = parameters['tag']
|
||||
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||
response = None
|
||||
try:
|
||||
|
||||
# Download Sample
|
||||
@ -29,7 +30,15 @@ class MagicDoer:
|
||||
model_details, classifier = cls.classifier_cache.get_default_classifier()
|
||||
|
||||
# do the majic
|
||||
results = classifier.predict(sample_file_path)
|
||||
predicted_class_name, labeled_predictions = classifier.predict(sample_file_path)
|
||||
|
||||
response = {
|
||||
"tag": tag,
|
||||
"probability": labeled_predictions[model_details['target_class_name']],
|
||||
"all_predictions": labeled_predictions,
|
||||
"class": predicted_class_name,
|
||||
"model": model_details['id']
|
||||
}
|
||||
|
||||
finally:
|
||||
try:
|
||||
@ -37,13 +46,10 @@ class MagicDoer:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
response = {
|
||||
"tag": tag,
|
||||
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
|
||||
"model": model_details['id']
|
||||
}
|
||||
|
||||
logging.info(f"Classification done!")
|
||||
logging.debug(f"Results: {response}")
|
||||
if not response:
|
||||
logging.error("Something went wrong during classification!")
|
||||
else:
|
||||
logging.info(f"Classification done!")
|
||||
logging.debug(f"Results: {response}")
|
||||
|
||||
return response
|
||||
|
@ -22,11 +22,12 @@ def message_callback(channel, method, properties, body):
|
||||
with start_transaction(op="cnn-classification-service", name="classify-soundfile"):
|
||||
results = MagicDoer.run_everything(msg) # <- This is where the magic happens
|
||||
|
||||
channel.basic_publish(
|
||||
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
||||
routing_key='classification-result',
|
||||
body=json.dumps(results).encode("utf-8")
|
||||
)
|
||||
if results:
|
||||
channel.basic_publish(
|
||||
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
||||
routing_key='classification-result',
|
||||
body=json.dumps(results).encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
Reference in New Issue
Block a user