2020-10-06 00:41:54 +02:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
import os
|
|
|
|
import logging
|
2021-07-30 12:17:18 +02:00
|
|
|
from requests_opentracing import SessionTracing
|
|
|
|
import opentracing # ez kell ide hogy a session tracer jolegyen
|
2020-10-06 00:41:54 +02:00
|
|
|
import requests
|
|
|
|
import tempfile
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
from urllib.parse import urljoin
|
|
|
|
|
|
|
|
from cnn_classifier import Classifier
|
2021-08-04 14:49:40 +02:00
|
|
|
from config import Config
|
|
|
|
import time
|
2020-10-06 00:41:54 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ClassifierCache:
    """Keep the model service's default classifier loaded and reasonably fresh.

    The model info document at ``model_info_url`` is (re-)fetched when no model
    is loaded yet, or when more than ``Config.MODEL_CACHE_LIFETIME_SEC`` seconds
    have passed since the last fetch. The model/weights files are downloaded and
    a new ``Classifier`` is built only when the advertised model ``id`` differs
    from the currently loaded one.
    """

    def __init__(self, model_info_url: str, *, request_timeout: Optional[float] = 60.0):
        """
        :param model_info_url: URL of the JSON document describing the default
            model. File URLs inside that document may be relative; they are
            resolved against this URL.
        :param request_timeout: timeout in seconds passed to every HTTP request
            made by this cache. With ``requests`` this bounds connecting and the
            gap between received bytes, not the total transfer time, so large
            weight files still download fine. ``None`` waits forever (the
            previous behaviour).
        """
        self._model_info_url = model_info_url
        self._request_timeout = request_timeout

        # Model info dict of the loaded model; None while nothing is loaded.
        self._current_model_details = None
        # Classifier built from the files of the loaded model.
        self._current_classifier = None
        # Temp files owned by this cache; deleted by _cleanup().
        self._downloaded_files = []

        # time.monotonic() timestamp of the last model info fetch; None = never.
        # A monotonic clock keeps the TTL correct across wall-clock adjustments.
        self._last_fetch_time = None

        self._session = SessionTracing(propagate=True)

    def _cleanup(self):
        """Unload the current classifier and delete every downloaded temp file.

        Files that are already gone are ignored. Other OS errors are logged
        rather than raised, so one undeletable temp file cannot prevent the next
        model from being loaded.
        """
        self._current_classifier = None
        self._current_model_details = None

        for path in self._downloaded_files:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            except OSError:
                logging.warning("Could not delete downloaded file %s", path, exc_info=True)

        self._downloaded_files = []

    def _download_to_tempfile(self, file_url: str, suffix: str) -> str:
        """Download ``file_url`` into a new temporary file and return its path.

        ``urljoin`` is used so both relative and absolute URLs are supported.
        The HTTP request completes (and is status-checked) BEFORE the temp file
        is created, so a failed download leaves no file or open descriptor
        behind. Once created, the path is registered in ``_downloaded_files``
        immediately, so ``_cleanup`` removes it even if writing fails later.

        :raises requests.HTTPError: on a non-success HTTP status.
        """
        r = self._session.get(
            urljoin(self._model_info_url, file_url),
            timeout=self._request_timeout,
        )
        r.raise_for_status()

        handle, path = tempfile.mkstemp(suffix=suffix)
        self._downloaded_files.append(path)
        with open(handle, 'wb') as f:  # takes ownership of the fd and closes it
            f.write(r.content)

        return path

    def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
        """Download the model (.json) and weights (.h5) files and build a Classifier from them."""
        logging.debug("Fetching model file...")
        model_file_path = self._download_to_tempfile(model_file_url, ".json")

        logging.debug("Fetching weights file...")
        weights_file_path = self._download_to_tempfile(weights_file_url, ".h5")

        # magic happens here
        self._current_classifier = Classifier(model_file_path, weights_file_path)

    def _cache_expired(self) -> bool:
        """Return True if model info was never fetched or is older than the configured lifetime."""
        if self._last_fetch_time is None:
            return True
        return (time.monotonic() - self._last_fetch_time) > Config.MODEL_CACHE_LIFETIME_SEC

    def _fetch_model_info(self) -> dict:
        """Fetch and validate the model info document.

        The fetch time is recorded as soon as the HTTP request succeeds, so an
        invalid document does not cause a refetch on every call while a valid
        model is still loaded.

        :raises requests.HTTPError: on a non-success HTTP status.
        :raises KeyError: if the document is not an object with a truthy ``id``.
        """
        logging.debug("Fetching model info...")
        r = self._session.get(self._model_info_url, timeout=self._request_timeout)
        r.raise_for_status()
        self._last_fetch_time = time.monotonic()

        model_details = r.json()

        if not isinstance(model_details, dict) or not model_details.get('id'):
            raise KeyError("Model info is invalid!")

        return model_details

    def _ensure_model_loaded(self, model_details: dict):
        """Load the model described by ``model_details`` unless it is already loaded."""
        current_model_id = self._current_model_details['id'] if self._current_model_details else None
        new_model_id = model_details['id']

        if current_model_id == new_model_id:
            logging.debug(
                "Currently loaded model seems up to date (%s == %s)", current_model_id, new_model_id
            )
            return

        logging.info(
            "Model needs to be loaded (local: %s; modelsvc def: %s)", current_model_id, new_model_id
        )

        # Resolve the file URLs BEFORE unloading anything: a document with a valid
        # id but missing file entries must not destroy the currently working model.
        files = model_details['files']
        model_file_url = files['model']
        weights_file_url = files['weights']

        self._cleanup()  # delete/unload everything
        self._download_and_load_model(model_file_url, weights_file_url)
        self._current_model_details = model_details

    def get_default_classifier(self) -> Tuple[dict, Classifier]:
        """Return ``(model_details, classifier)`` for the model service's default model.

        The model info is re-fetched when nothing is loaded or the cache has
        expired. A new model is downloaded only if its id changed.

        :raises requests.RequestException: if the model service cannot be reached.
        :raises KeyError: if the model info document is invalid.
        """
        if self._cache_expired() or not self._current_model_details:
            self._ensure_model_loaded(self._fetch_model_info())
        else:
            logging.debug("Cache is still valid. Not fetching model info")

        return self._current_model_details, self._current_classifier
|