Added chirp targeting
This commit is contained in:
parent
bc88a6aa45
commit
fcbf54e90b
@ -20,7 +20,7 @@ if SENTRY_DSN:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_classification(task):
|
def run_classification(task, target_class_name: str):
|
||||||
_, temp_model_name = tempfile.mkstemp()
|
_, temp_model_name = tempfile.mkstemp()
|
||||||
temp_means_name = temp_model_name + "MEANS"
|
temp_means_name = temp_model_name + "MEANS"
|
||||||
|
|
||||||
@ -51,6 +51,8 @@ def run_classification(task):
|
|||||||
classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \
|
classifier, mean, std, classes, mid_window, mid_step, short_window, short_step, compute_beat \
|
||||||
= load_model(temp_model_name)
|
= load_model(temp_model_name)
|
||||||
|
|
||||||
|
target_id = classes.index(target_class_name) # Might raise ValueError
|
||||||
|
|
||||||
feature_vector = (numpy.array(task['features']) - mean) / std
|
feature_vector = (numpy.array(task['features']) - mean) / std
|
||||||
class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector)
|
class_id, probability = classifier_wrapper(classifier, model_details['type'], feature_vector)
|
||||||
|
|
||||||
@ -68,8 +70,8 @@ def run_classification(task):
|
|||||||
results = {
|
results = {
|
||||||
"tag": task['tag'],
|
"tag": task['tag'],
|
||||||
"model": task['model'],
|
"model": task['model'],
|
||||||
"class_id": class_id,
|
"is_target": class_id == target_id,
|
||||||
"probability": probability
|
"probability": probability[target_id]
|
||||||
}
|
}
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -83,8 +85,9 @@ def main():
|
|||||||
while True:
|
while True:
|
||||||
message = uwsgi.mule_get_msg()
|
message = uwsgi.mule_get_msg()
|
||||||
task = json.loads(message)
|
task = json.loads(message)
|
||||||
results = run_classification(task)
|
results = run_classification(task, os.environ['TARGET_CLASS_NAME'])
|
||||||
channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result', body=json.dumps(results).encode("utf-8"))
|
channel.basic_publish(exchange=os.environ['PIKA_EXCHANGE'], routing_key='classification-result',
|
||||||
|
body=json.dumps(results).encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Reference in New Issue
Block a user