cnn-classification-service/cnn_classification_service/classifier_cache.py

79 lines
2.7 KiB
Python
Raw Normal View History

2020-10-06 00:41:54 +02:00
#!/usr/bin/env python3
import os
import logging
2021-07-30 12:17:18 +02:00
from requests_opentracing import SessionTracing
import opentracing  # needed here so that the session tracer works properly
2020-10-06 00:41:54 +02:00
import requests
import tempfile
from typing import Tuple
from urllib.parse import urljoin
from cnn_classifier import Classifier
class ClassifierCache:
    """Keep a classifier loaded for the model advertised at the model info URL.

    The model info URL is expected to return a JSON document containing an
    ``id`` and a ``files`` mapping with ``model`` and ``weights`` entries.
    Both files are downloaded into temporary files; those files are deleted
    again when the cached model is replaced or when loading fails.
    """

    def __init__(self, model_info_url: str):
        """Create an empty cache.

        :param model_info_url: URL queried for the default model's details.
            It is also the base that relative file URLs are resolved against.
        """
        self._model_info_url = model_info_url

        self._current_model_details = None  # None until a model has been loaded
        self._current_classifier = None  # Classifier built from the cached model
        self._downloaded_files = []  # Temporary files to delete on cleanup

        self._session = SessionTracing(propagate=True)

    def _cleanup(self):
        """Drop the cached classifier and delete every tracked temporary file."""
        self._current_classifier = None
        self._current_model_details = None
        for path in self._downloaded_files:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass  # Already gone, nothing left to do
            except OSError:
                # Keep going: one stubborn file must not prevent the rest of
                # the cleanup (or mask the error that triggered it).
                logging.warning("Could not remove temporary file %s", path, exc_info=True)

        self._downloaded_files = []

    def _fetch_to_tempfile(self, file_url: str, suffix: str) -> str:
        """Download ``file_url`` into a new temporary file and return its path.

        The temporary file is only created once the download has succeeded,
        and its path is tracked before writing so that a failed write is
        still removed by ``_cleanup``.

        :raises requests.HTTPError: if the server answers with an error status.
        """
        # Fun fact: urljoin is used to support both relative and absolute urls
        response = self._session.get(urljoin(self._model_info_url, file_url))
        response.raise_for_status()

        handle, path = tempfile.mkstemp(suffix=suffix)
        self._downloaded_files.append(path)
        with open(handle, 'wb') as f:  # Closing the file object closes the descriptor
            f.write(response.content)

        return path

    def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
        """Download the model and weights files and build the classifier.

        If any step fails, everything downloaded so far is deleted and the
        original exception is re-raised.
        """
        try:
            logging.debug("Fetching model file...")
            model_file_path = self._fetch_to_tempfile(model_file_url, ".json")

            logging.debug("Fetching weights file...")
            weights_file_path = self._fetch_to_tempfile(weights_file_url, ".h5")

            # magic happens here
            self._current_classifier = Classifier(model_file_path, weights_file_path)
        except Exception:
            # Do not leave half-downloaded models behind in the temp directory
            self._cleanup()
            raise

    def get_default_classifier(self) -> "Tuple[dict, Classifier]":
        """Return the details and classifier of the current default model.

        The model info URL is queried on every call; the model files are only
        downloaded again when the advertised ``id`` differs from the cached one.

        :returns: ``(model_details, classifier)``.
        :raises requests.HTTPError: if a request answers with an error status.
        """
        logging.debug("Fetching model info...")
        r = self._session.get(self._model_info_url)
        r.raise_for_status()

        model_details = r.json()

        if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']):
            # If the currently loaded model is not the default... then load it
            self._cleanup()  # delete/unload everything
            self._download_and_load_model(model_details['files']['model'], model_details['files']['weights'])
            self._current_model_details = model_details

        return self._current_model_details, self._current_classifier