Added more stuff to results
This commit is contained in:
parent
a64b14b06f
commit
07689594b3
@ -53,7 +53,7 @@ class Classifier(object):
|
|||||||
|
|
||||||
return target_dir, file_name # Az unknown nélkülivel kell visszatérni
|
return target_dir, file_name # Az unknown nélkülivel kell visszatérni
|
||||||
|
|
||||||
def _run_predictor(self, directory: str) -> list:
|
def _run_predictor(self, directory: str) -> Tuple[str, dict]:
|
||||||
predict_generator = self.datagen.flow_from_directory(
|
predict_generator = self.datagen.flow_from_directory(
|
||||||
directory=directory,
|
directory=directory,
|
||||||
batch_size=128,
|
batch_size=128,
|
||||||
@ -77,11 +77,13 @@ class Classifier(object):
|
|||||||
}
|
}
|
||||||
labels = dict((v, k) for k, v in labels.items())
|
labels = dict((v, k) for k, v in labels.items())
|
||||||
|
|
||||||
predictions = [labels[k] for k in predicted_class_indices]
|
labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])}
|
||||||
|
|
||||||
return predictions
|
predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh?
|
||||||
|
|
||||||
def predict(self, wav_filename: str) -> list:
|
return predicted_class_name, labeled_predictions
|
||||||
|
|
||||||
|
def predict(self, wav_filename: str) -> Tuple[str, dict]:
|
||||||
directory, _ = self.create_spectrogram(wav_filename)
|
directory, _ = self.create_spectrogram(wav_filename)
|
||||||
|
|
||||||
result = self._run_predictor(directory)
|
result = self._run_predictor(directory)
|
||||||
|
@ -14,6 +14,7 @@ class MagicDoer:
|
|||||||
def run_everything(cls, parameters: dict) -> dict:
|
def run_everything(cls, parameters: dict) -> dict:
|
||||||
tag = parameters['tag']
|
tag = parameters['tag']
|
||||||
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
sample_file_handle, sample_file_path = tempfile.mkstemp(prefix=f"{tag}_", suffix=".wav")
|
||||||
|
response = None
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Download Sample
|
# Download Sample
|
||||||
@ -29,7 +30,15 @@ class MagicDoer:
|
|||||||
model_details, classifier = cls.classifier_cache.get_default_classifier()
|
model_details, classifier = cls.classifier_cache.get_default_classifier()
|
||||||
|
|
||||||
# do the majic
|
# do the majic
|
||||||
results = classifier.predict(sample_file_path)
|
predicted_class_name, labeled_predictions = classifier.predict(sample_file_path)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"tag": tag,
|
||||||
|
"probability": labeled_predictions[model_details['target_class_name']],
|
||||||
|
"all_predictions": labeled_predictions,
|
||||||
|
"class": predicted_class_name,
|
||||||
|
"model": model_details['id']
|
||||||
|
}
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
@ -37,13 +46,10 @@ class MagicDoer:
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
response = {
|
if not response:
|
||||||
"tag": tag,
|
logging.error("Something went wrong during classification!")
|
||||||
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
|
else:
|
||||||
"model": model_details['id']
|
logging.info(f"Classification done!")
|
||||||
}
|
logging.debug(f"Results: {response}")
|
||||||
|
|
||||||
logging.info(f"Classification done!")
|
|
||||||
logging.debug(f"Results: {response}")
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -22,11 +22,12 @@ def message_callback(channel, method, properties, body):
|
|||||||
with start_transaction(op="cnn-classification-service", name="classify-soundfile"):
|
with start_transaction(op="cnn-classification-service", name="classify-soundfile"):
|
||||||
results = MagicDoer.run_everything(msg) # <- This is where the magic happens
|
results = MagicDoer.run_everything(msg) # <- This is where the magic happens
|
||||||
|
|
||||||
channel.basic_publish(
|
if results:
|
||||||
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
channel.basic_publish(
|
||||||
routing_key='classification-result',
|
exchange=os.environ['PIKA_OUTPUT_EXCHANGE'],
|
||||||
body=json.dumps(results).encode("utf-8")
|
routing_key='classification-result',
|
||||||
)
|
body=json.dumps(results).encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
Reference in New Issue
Block a user