Added chirp targeting
This commit is contained in:
		@@ -20,7 +20,7 @@ if SENTRY_DSN:
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_classification(task):
 | 
					def run_classification(task, target_class_name: str):
 | 
				
			||||||
    _, temp_model_name = tempfile.mkstemp()
 | 
					    _, temp_model_name = tempfile.mkstemp()
 | 
				
			||||||
    temp_means_name = temp_model_name + "MEANS"
 | 
					    temp_means_name = temp_model_name + "MEANS"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -51,6 +51,8 @@ def run_classification(task):
 | 
				
			|||||||
            classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \
 | 
					            classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \
 | 
				
			||||||
                = load_model(temp_model_name)
 | 
					                = load_model(temp_model_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        target_id = classes.index(target_class_name)  # Might raise ValueError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        feature_vector = (numpy.array(task['features']) - mean) / std
 | 
					        feature_vector = (numpy.array(task['features']) - mean) / std
 | 
				
			||||||
        class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector)
 | 
					        class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -68,8 +70,8 @@ def run_classification(task):
 | 
				
			|||||||
    results = {
 | 
					    results = {
 | 
				
			||||||
        "tag": task['tag'],
 | 
					        "tag": task['tag'],
 | 
				
			||||||
        "model": task['model'],
 | 
					        "model": task['model'],
 | 
				
			||||||
        "class_id": class_id,
 | 
					        "is_target": class_id == target_id,
 | 
				
			||||||
        "probability": probability
 | 
					        "probability": probability[target_id]
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return results
 | 
					    return results
 | 
				
			||||||
@@ -83,8 +85,9 @@ def main():
 | 
				
			|||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        message = uwsgi.mule_get_msg()
 | 
					        message = uwsgi.mule_get_msg()
 | 
				
			||||||
        task = json.loads(message)
 | 
					        task = json.loads(message)
 | 
				
			||||||
        results = run_classification(task)
 | 
					        results = run_classification(task, os.environ['TARGET_CLASS_NAME'])
 | 
				
			||||||
        channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result', body=json.dumps(results).encode("utf-8"))
 | 
					        channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result',
 | 
				
			||||||
 | 
					                              body=json.dumps(results).encode("utf-8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user