Updated for new model service api
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				continuous-integration/drone/push Build is passing
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	continuous-integration/drone/push Build is passing
				
			This commit is contained in:
		@@ -4,6 +4,8 @@ import logging
 | 
			
		||||
import tempfile
 | 
			
		||||
import requests
 | 
			
		||||
 | 
			
		||||
from urllib.parse import urljoin
 | 
			
		||||
 | 
			
		||||
from cnn_classifier import Classifier
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -16,6 +18,8 @@ def run_everything(parameters: dict):
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
 | 
			
		||||
        # Download Sample
 | 
			
		||||
 | 
			
		||||
        logging.info(f"Downloading sample: {tag}")
 | 
			
		||||
        r = requests.get(f"http://storage-service/object/{tag}")
 | 
			
		||||
        with open(sample_file_handle, 'wb') as f:
 | 
			
		||||
@@ -23,13 +27,25 @@ def run_everything(parameters: dict):
 | 
			
		||||
 | 
			
		||||
        logging.debug(f"Downloaded sample to {sample_file_path}")
 | 
			
		||||
 | 
			
		||||
        r = requests.get(f"http://model-service/model/cnn/$default")
 | 
			
		||||
        # Download model
 | 
			
		||||
 | 
			
		||||
        model_root_url = "http://model-service/model/cnn/$default"
 | 
			
		||||
 | 
			
		||||
        logging.debug("Fetching model info...")
 | 
			
		||||
        r = requests.get(model_root_url)
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        model_details = r.json()
 | 
			
		||||
 | 
			
		||||
        logging.debug("Fetching model file...")
 | 
			
		||||
        r = requests.get(urljoin(model_root_url, model_details['files']['model']))  # Fun fact: this would support external urls
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        with open(model_file_handle, 'wb') as f:
 | 
			
		||||
            f.write(r.content)
 | 
			
		||||
 | 
			
		||||
        r = requests.get(f"http://model-service/model/cnn/$default?weights")
 | 
			
		||||
        logging.debug("Fetching weights file...")
 | 
			
		||||
        r = requests.get(urljoin(model_root_url, model_details['files']['weights']))
 | 
			
		||||
        r.raise_for_status()
 | 
			
		||||
 | 
			
		||||
        with open(weights_file_handle, 'wb') as f:
 | 
			
		||||
@@ -57,8 +73,8 @@ def run_everything(parameters: dict):
 | 
			
		||||
 | 
			
		||||
    response = {
 | 
			
		||||
        "tag": tag,
 | 
			
		||||
        "probability": 1.0 if results[0] == 'sturnus' else 0.0,
 | 
			
		||||
        "model": "TODO"
 | 
			
		||||
        "probability": 1.0 if results[0] == model_details['target_class_name'] else 0.0,
 | 
			
		||||
        "model": model_details['id']
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    logging.info(f"Classification done!")
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user