diff --git a/cnn_classification_service/cnn_classifier.py b/cnn_classification_service/cnn_classifier.py index 47e9201..083ed02 100644 --- a/cnn_classification_service/cnn_classifier.py +++ b/cnn_classification_service/cnn_classifier.py @@ -77,7 +77,7 @@ class Classifier(object): } labels = dict((v, k) for k, v in labels.items()) - labeled_predictions = {labels[i]: p for i, p in enumerate(prediction[0])} + labeled_predictions = {labels[i]: float(p) for i, p in enumerate(prediction[0])} predicted_class_name = [labels[k] for k in predicted_class_indices][0] # eh?