Updated for new model service api
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
parent
4d7847d58e
commit
4003aa73ac
@ -4,6 +4,8 @@ import logging
|
|||||||
import tempfile
|
import tempfile
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
from cnn_classifier import Classifier
|
from cnn_classifier import Classifier
|
||||||
|
|
||||||
|
|
||||||
@ -16,6 +18,8 @@ def run_everything(parameters: dict):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
# Download Sample
|
||||||
|
|
||||||
logging.info(f"Downloading sample: {tag}")
|
logging.info(f"Downloading sample: {tag}")
|
||||||
r = requests.get(f"http://storage-service/object/{tag}")
|
r = requests.get(f"http://storage-service/object/{tag}")
|
||||||
with open(sample_file_handle, 'wb') as f:
|
with open(sample_file_handle, 'wb') as f:
|
||||||
@ -23,13 +27,25 @@ def run_everything(parameters: dict):
|
|||||||
|
|
||||||
logging.debug(f"Downloaded sample to {sample_file_path}")
|
logging.debug(f"Downloaded sample to {sample_file_path}")
|
||||||
|
|
||||||
r = requests.get(f"http://model-service/model/cnn/$default")
|
# Download model
|
||||||
|
|
||||||
|
model_root_url = "http://model-service/model/cnn/$default"
|
||||||
|
|
||||||
|
logging.debug("Fetching model info...")
|
||||||
|
r = requests.get(model_root_url)
|
||||||
|
r.raise_for_status()
|
||||||
|
|
||||||
|
model_details = r.json()
|
||||||
|
|
||||||
|
logging.debug("Fetching model file...")
|
||||||
|
r = requests.get(urljoin(model_root_url, model_details['files']['model'])) # Fun fact: this would support external urls
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
with open(model_file_handle, 'wb') as f:
|
with open(model_file_handle, 'wb') as f:
|
||||||
f.write(r.content)
|
f.write(r.content)
|
||||||
|
|
||||||
r = requests.get(f"http://model-service/model/cnn/$default?weights")
|
logging.debug("Fetching weights file...")
|
||||||
|
r = requests.get(urljoin(model_root_url, model_details['files']['weights']))
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
with open(weights_file_handle, 'wb') as f:
|
with open(weights_file_handle, 'wb') as f:
|
||||||
@ -57,8 +73,8 @@ def run_everything(parameters: dict):
|
|||||||
|
|
||||||
response = {
|
response = {
|
||||||
"tag": tag,
|
"tag": tag,
|
||||||
"probability": 1.0 if results[0] == 'sturnus' else 0.0,
|
"probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
|
||||||
"model": "TODO"
|
"model": model_details['id']
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.info(f"Classification done!")
|
logging.info(f"Classification done!")
|
||||||
|
Loading…
Reference in New Issue
Block a user