diff --git a/cnn_classification_service/cnn_classifier.py b/cnn_classification_service/cnn_classifier.py index 617527d..fefecc3 100644 --- a/cnn_classification_service/cnn_classifier.py +++ b/cnn_classification_service/cnn_classifier.py @@ -12,7 +12,7 @@ import numpy import matplotlib.pyplot from tensorflow.keras.models import model_from_json from tensorflow.keras import optimizers -from tensorflow.keras_preprocessing.image import ImageDataGenerator +from keras_preprocessing.image import ImageDataGenerator class Classifier(object): diff --git a/cnn_classification_service/config.py b/cnn_classification_service/config.py index bf17a26..a0b9d6e 100644 --- a/cnn_classification_service/config.py +++ b/cnn_classification_service/config.py @@ -4,8 +4,8 @@ import logging class Config: - PIKA_URL = os.environ['PIKA_URL'] - PIKA_OUTPUT_EXCHANGE = os.environ['PIKA_OUTPUT_EXCHANGE'] + PIKA_URL = os.environ.get('PIKA_URL') + PIKA_OUTPUT_EXCHANGE = os.environ.get('PIKA_OUTPUT_EXCHANGE') PIKA_INPUT_EXCHANGE = os.environ.get('PIKA_INPUT_EXCHANGE', 'sample-ready') MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default") diff --git a/cnn_classification_service/main.py b/cnn_classification_service/main.py index 2828a2c..b218a67 100644 --- a/cnn_classification_service/main.py +++ b/cnn_classification_service/main.py @@ -17,6 +17,7 @@ import sentry_sdk from config import Config from magic_doer import MagicDoer +from classifier_cache import ClassifierCache def message_callback(channel, method, properties, body): @@ -58,15 +59,10 @@ def message_callback(channel, method, properties, body): def main(): - # setup logging - - logging.basicConfig( - stream=sys.stdout, - format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s", - level=Config.LOG_LEVEL - ) - # setup observability stuffs + if (not Config.PIKA_URL) or (not Config.PIKA_OUTPUT_EXCHANGE): + logging.error("Mandatory config parameters unset: PIKA_URL or PIKA_OUTPUT_EXCHANGE") + raise KeyError if Config.SENTRY_DSN: sentry_logging = LoggingIntegration( @@ -112,5 +108,24 @@ def main(): opentracing.tracer.close() +def test_loader(): + logging.info("Testing if model loading works...") + cc = ClassifierCache(Config.MODEL_INFO_URL) + details, classifier = cc.get_default_classifier() + logging.info(f"Loaded classifier: {classifier}") + logging.info(f"Details: {details}") + + if __name__ == '__main__': - main() + # setup logging + + logging.basicConfig( + stream=sys.stdout, + format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s", + level=Config.LOG_LEVEL + ) + + if '--test-loader' in sys.argv: + test_loader() + else: + main() diff --git a/requirements.txt b/requirements.txt index 3fa5213..b1a6947 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ sentry_sdk pika requests - +h5py < 3.0.0 librosa keras numpy