From 4003aa73ac342e18c33887d1d2263b87a4ad8062 Mon Sep 17 00:00:00 2001 From: marcsello Date: Fri, 2 Oct 2020 03:59:09 +0200 Subject: [PATCH] Updated for new model service api --- cnn_classification_service/magic_doer.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/cnn_classification_service/magic_doer.py b/cnn_classification_service/magic_doer.py index 46045e3..a31b01b 100644 --- a/cnn_classification_service/magic_doer.py +++ b/cnn_classification_service/magic_doer.py @@ -4,6 +4,8 @@ import logging import tempfile import requests +from urllib.parse import urljoin + from cnn_classifier import Classifier @@ -16,6 +18,8 @@ def run_everything(parameters: dict): try: + # Download Sample + logging.info(f"Downloading sample: {tag}") r = requests.get(f"http://storage-service/object/{tag}") with open(sample_file_handle, 'wb') as f: @@ -23,13 +27,25 @@ def run_everything(parameters: dict): logging.debug(f"Downloaded sample to {sample_file_path}") - r = requests.get(f"http://model-service/model/cnn/$default") + # Download model + + model_root_url = "http://model-service/model/cnn/$default" + + logging.debug("Fetching model info...") + r = requests.get(model_root_url) + r.raise_for_status() + + model_details = r.json() + + logging.debug("Fetching model file...") + r = requests.get(urljoin(model_root_url, model_details['files']['model'])) # Fun fact: this would support external urls r.raise_for_status() with open(model_file_handle, 'wb') as f: f.write(r.content) - r = requests.get(f"http://model-service/model/cnn/$default?weights") + logging.debug("Fetching weights file...") + r = requests.get(urljoin(model_root_url, model_details['files']['weights'])) r.raise_for_status() with open(weights_file_handle, 'wb') as f: @@ -57,8 +73,8 @@ def run_everything(parameters: dict): response = { "tag": tag, - "probability": 1.0 if results[0] == 'sturnus' else 0.0, - "model": "TODO" + "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0, + "model": model_details['id'] } logging.info(f"Classification done!")