Implemented classifier cache
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
import tempfile
|
||||
from typing import Tuple
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cnn_classifier import Classifier
|
||||
|
||||
|
||||
class ClassifierCache:
|
||||
|
||||
def __init__(self, default_model_info_url: str = "http://model-service/model/cnn/$default"):
|
||||
self._default_model_info_url = default_model_info_url
|
||||
|
||||
self._current_model_details = None # Should never be equal to the default model id
|
||||
self._current_classifier = None # Latest classifier is a classifier that uses the $default model
|
||||
self._downloaded_files = []
|
||||
|
||||
self._session = requests.Session()
|
||||
|
||||
def _cleanup(self):
|
||||
self._current_classifier = None
|
||||
self._current_model_details = None
|
||||
for file in self._downloaded_files:
|
||||
try:
|
||||
os.unlink(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
self._downloaded_files = []
|
||||
|
||||
def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
|
||||
model_file_handle, model_file_path = tempfile.mkstemp(suffix=".json")
|
||||
weights_file_handle, weights_file_path = tempfile.mkstemp(suffix=".h5")
|
||||
|
||||
logging.debug("Fetching model file...")
|
||||
r = self._session.get( # Fun fact: urljoin is used to support both relative and absolute urls
|
||||
urljoin(self._default_model_info_url, model_file_url)
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
with open(model_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
self._downloaded_files.append(model_file_path)
|
||||
|
||||
logging.debug("Fetching weights file...")
|
||||
r = self._session.get(
|
||||
urljoin(self._default_model_info_url, weights_file_url)
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
with open(weights_file_handle, 'wb') as f:
|
||||
f.write(r.content)
|
||||
|
||||
self._downloaded_files.append(weights_file_path)
|
||||
|
||||
# magic happens here
|
||||
self._current_classifier = Classifier(model_file_path, weights_file_path)
|
||||
|
||||
def get_default_classifier(self) -> Tuple[dict, Classifier]:
|
||||
logging.debug("Fetching model info...")
|
||||
r = self._session.get(self._default_model_info_url)
|
||||
r.raise_for_status()
|
||||
|
||||
model_details = r.json()
|
||||
|
||||
if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']):
|
||||
# If the currently loaded model is not the default... then load it
|
||||
self._cleanup() # delete/unload everything
|
||||
self._download_and_load_model(model_details['files']['model'], model_details['files']['weights'])
|
||||
self._current_model_details = model_details
|
||||
|
||||
return self._current_model_details, self._current_classifier
|
||||
|
||||
Reference in New Issue
Block a user