#!/usr/bin/env python3
import os
import logging
from requests_opentracing import SessionTracing
import opentracing  # This import is needed here so that the session tracer works correctly
import requests
import tempfile
from typing import Tuple
from urllib.parse import urljoin
from cnn_classifier import Classifier
from config import Config
import time


class ClassifierCache:
    """Keep a classifier loaded for the model service's current default model.

    Model info is fetched from ``model_info_url`` when nothing is loaded yet,
    or when more than ``Config.MODEL_CACHE_LIFETIME_SEC`` seconds have passed
    since the last fetch. If the reported model ``id`` differs from the loaded
    one, the model and weights files listed in the info are downloaded into
    temporary files and a new ``Classifier`` is built from them.
    """

    def __init__(self, model_info_url: str):
        """
        :param model_info_url: URL that returns the model info as JSON. File
            URLs inside that info are resolved against this URL, so they may
            be relative or absolute.
        """
        self._model_info_url = model_info_url
        # Model info dict of the currently loaded model (as returned by the
        # model info endpoint), or None when no model is loaded.
        self._current_model_details = None
        # Classifier built from the files of the model described by
        # _current_model_details, or None when no model is loaded.
        self._current_classifier = None
        # Temporary files created for the loaded model; removed by _cleanup().
        self._downloaded_files = []
        # time.time() of the last successful model info fetch (0 = never).
        self._last_fetch_time = 0
        self._session = SessionTracing(propagate=True)

    def _cleanup(self):
        """Unload the current model and delete every downloaded temp file."""
        self._current_classifier = None
        self._current_model_details = None
        for file in self._downloaded_files:
            try:
                os.unlink(file)
            except FileNotFoundError:
                # Already gone; cleanup is best-effort.
                pass
        self._downloaded_files = []

    def _download_file(self, file_url: str, suffix: str) -> str:
        """Download ``file_url`` into a new temporary file and return its path.

        The path is registered in ``_downloaded_files`` as soon as the file is
        created, so ``_cleanup()`` removes it even if the download fails.

        :param file_url: relative or absolute URL, resolved against the model
            info URL with ``urljoin``.
        :param suffix: suffix for the temporary file name (e.g. ``".json"``).
        :raises requests.HTTPError: if the server answers with an error status.
        """
        file_handle, file_path = tempfile.mkstemp(suffix=suffix)
        self._downloaded_files.append(file_path)
        # Wrap the raw descriptor right away so it is closed on every path,
        # including when the request below raises.
        with open(file_handle, 'wb') as f:
            # urljoin is used to support both relative and absolute urls
            r = self._session.get(urljoin(self._model_info_url, file_url))
            r.raise_for_status()
            f.write(r.content)
        return file_path

    def _download_and_load_model(self, model_file_url: str, weights_file_url: str):
        """Download the model and weights files and build a Classifier from them.

        The result is stored in ``_current_classifier``.

        :raises requests.HTTPError: if either download returns an error status.
        """
        logging.debug("Fetching model file...")
        model_file_path = self._download_file(model_file_url, ".json")

        logging.debug("Fetching weights file...")
        weights_file_path = self._download_file(weights_file_url, ".h5")

        self._current_classifier = Classifier(model_file_path, weights_file_path)

    def get_default_classifier(self) -> Tuple[dict, Classifier]:
        """Return ``(model_details, classifier)`` for the service's default model.

        The model info is re-fetched when no model is loaded, or when the cache
        lifetime has expired. The model is (re)downloaded only when the
        reported ``id`` differs from the loaded one.

        :raises requests.HTTPError: if the model info or a model file download
            returns an error status.
        """
        if ((time.time() - self._last_fetch_time) > Config.MODEL_CACHE_LIFETIME_SEC) or \
                (not self._current_model_details):
            logging.debug("Fetching model info...")
            r = self._session.get(self._model_info_url)
            r.raise_for_status()
            self._last_fetch_time = time.time()
            model_details = r.json()

            if (not self._current_model_details) or (self._current_model_details['id'] != model_details['id']):
                # Nothing is loaded yet on the first call, so there may be no local id to report.
                local_id = self._current_model_details['id'] if self._current_model_details else None
                logging.info(f"Model needs to be loaded (local: {local_id} model service default: {model_details['id']})")
                # The currently loaded model is not the default one, so replace it.
                self._cleanup()  # delete/unload everything
                self._download_and_load_model(model_details['files']['model'], model_details['files']['weights'])
                # Record the details only after the new model loaded successfully.
                self._current_model_details = model_details
            else:
                logging.debug(f"Currently loaded model seems up to date ({self._current_model_details['id']} == {model_details['id']})")
        else:
            logging.debug("Cache is still valid. Not fetching model info")

        return self._current_model_details, self._current_classifier