diff --git a/classification_service/mule.py b/classification_service/mule.py index cd197d8..4655595 100644 --- a/classification_service/mule.py +++ b/classification_service/mule.py @@ -20,7 +20,7 @@ if SENTRY_DSN: ) -def run_classification(task): +def run_classification(task, target_class_name: str): _, temp_model_name = tempfile.mkstemp() temp_means_name = temp_model_name + "MEANS" @@ -51,6 +51,8 @@ def run_classification(task): classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \ = load_model(temp_model_name) + target_id = classes.index(target_class_name) # Might raise ValueError + feature_vector = (numpy.array(task['features']) - mean) / std class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector) @@ -68,8 +70,8 @@ def run_classification(task): results = { "tag": task['tag'], "model": task['model'], - "class_id": class_id, - "probability": probability + "is_target": class_id == target_id, + "probability": probability[target_id] } return results @@ -83,8 +85,9 @@ def main(): while True: message = uwsgi.mule_get_msg() task = json.loads(message) - results = run_classification(task) - channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result', body=json.dumps(results).encode("utf-8")) + results = run_classification(task, os.environ['TARGET_CLASS_NAME']) + channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result', + body=json.dumps(results).encode("utf-8")) if __name__ == '__main__':