diff --git a/cnn_classification_service/classifier_cache.py b/cnn_classification_service/classifier_cache.py index b4f14e5..aeec0e5 100644 --- a/cnn_classification_service/classifier_cache.py +++ b/cnn_classification_service/classifier_cache.py @@ -11,8 +11,8 @@ from cnn_classifier import Classifier class ClassifierCache: - def __init__(self, default_model_info_url: str = "http://model-service/model/cnn/$default"): - self._default_model_info_url = default_model_info_url + def __init__(self, model_info_url: str): + self._model_info_url = model_info_url self._current_model_details = None # Should never be equal to the default model id self._current_classifier = None # Latest classifier is a classifier that uses the $default model @@ -37,7 +37,7 @@ class ClassifierCache: logging.debug("Fetching model file...") r = self._session.get( # Fun fact: urljoin is used to support both relative and absolute urls - urljoin(self._default_model_info_url, model_file_url) + urljoin(self._model_info_url, model_file_url) ) r.raise_for_status() @@ -48,7 +48,7 @@ class ClassifierCache: logging.debug("Fetching weights file...") r = self._session.get( - urljoin(self._default_model_info_url, weights_file_url) + urljoin(self._model_info_url, weights_file_url) ) r.raise_for_status() @@ -62,7 +62,7 @@ class ClassifierCache: def get_default_classifier(self) -> Tuple[dict, Classifier]: logging.debug("Fetching model info...") - r = self._session.get(self._default_model_info_url) + r = self._session.get(self._model_info_url) r.raise_for_status() model_details = r.json() diff --git a/cnn_classification_service/config.py b/cnn_classification_service/config.py new file mode 100644 index 0000000..154393f --- /dev/null +++ b/cnn_classification_service/config.py @@ -0,0 +1,15 @@ +import os + + +class Config: + PIKA_URL = os.environ['PIKA_URL'] + PIKA_OUTPUT_EXCHANGE = os.environ['PIKA_OUTPUT_EXCHANGE'] + PIKA_INPUT_EXCHANGE = os.environ['PIKA_INPUT_EXCHANGE'] + + MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default") + STORAGE_SERVICE_URL = os.environ.get("STORAGE_SERVICE_URL", "http://storage-service/") + + SENTRY_DSN = os.environ.get("SENTRY_DSN") + + RELEASE_ID = os.environ.get('RELEASE_ID', 'test') + RELEASEMODE = os.environ.get('RELEASEMODE', 'dev') diff --git a/cnn_classification_service/magic_doer.py b/cnn_classification_service/magic_doer.py index c0e19b9..0ed5004 100644 --- a/cnn_classification_service/magic_doer.py +++ b/cnn_classification_service/magic_doer.py @@ -6,13 +6,13 @@ import requests import time from urllib.parse import urljoin +from config import Config + from classifier_cache import ClassifierCache class MagicDoer: - classifier_cache = ClassifierCache( - os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default") - ) + classifier_cache = ClassifierCache(Config.MODEL_INFO_URL) @classmethod def run_everything(cls, parameters: dict) -> dict: @@ -22,8 +22,7 @@ class MagicDoer: try: # Download Sample - storage_service_url = os.environ.get("STORAGE_SERVICE_URL", "http://storage-service/") - object_path = urljoin(storage_service_url, f"object/{tag}") + object_path = urljoin(Config.STORAGE_SERVICE_URL, f"object/{tag}") logging.info(f"Downloading sample: {tag} from {object_path}") r = requests.get(object_path) diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py index 4842afe..8acb2df 100644 --- a/cnn_classification_service/main.py +++ b/cnn_classification_service/main.py @@ -9,6 +9,8 @@ from sentry_sdk.integrations.logging import LoggingIntegration from sentry_sdk import start_transaction import sentry_sdk +from config import Config + from magic_doer import MagicDoer @@ -24,7 +26,7 @@ def message_callback(channel, method, properties, body): if results: channel.basic_publish( - exchange=os.environ['PIKA_OUTPUT_EXCHANGE'], + exchange=Config.PIKA_OUTPUT_EXCHANGE, routing_key='classification-result', body=json.dumps(results).encode("utf-8") ) @@ -38,31 +40,30 @@ def main(): level=logging.DEBUG if '--debug' in sys.argv else logging.INFO ) - SENTRY_DSN = os.environ.get("SENTRY_DSN") - if SENTRY_DSN: + if Config.SENTRY_DSN: sentry_logging = LoggingIntegration( level=logging.DEBUG, # Capture info and above as breadcrumbs event_level=logging.ERROR # Send errors as events ) sentry_sdk.init( - dsn=SENTRY_DSN, + dsn=Config.SENTRY_DSN, integrations=[sentry_logging], traces_sample_rate=1.0, send_default_pii=True, - release=os.environ.get('RELEASE_ID', 'test'), - environment=os.environ.get('RELEASEMODE', 'dev'), + release=Config.RELEASE_ID, + environment=Config.RELEASEMODE, _experiments={"auto_enabling_integrations": True} ) logging.info("Connecting to MQ service...") - connection = pika.BlockingConnection(pika.connection.URLParameters(os.environ['PIKA_URL'])) + connection = pika.BlockingConnection(pika.connection.URLParameters(Config.PIKA_URL)) channel = connection.channel() - channel.exchange_declare(exchange=os.environ['PIKA_INPUT_EXCHANGE'], exchange_type='direct') + channel.exchange_declare(exchange=Config.PIKA_INPUT_EXCHANGE, exchange_type='direct') queue_declare_result = channel.queue_declare(queue='cnnqueue', exclusive=False) queue_name = queue_declare_result.method.queue - channel.queue_bind(exchange=os.environ['PIKA_INPUT_EXCHANGE'], routing_key='feature', queue=queue_name) + channel.queue_bind(exchange=Config.PIKA_INPUT_EXCHANGE, routing_key='feature', queue=queue_name) channel.basic_qos(prefetch_count=1)