This commit is contained in:
		@@ -12,7 +12,7 @@ import numpy
 | 
			
		||||
import matplotlib.pyplot
 | 
			
		||||
from tensorflow.keras.models import model_from_json
 | 
			
		||||
from tensorflow.keras import optimizers
 | 
			
		||||
from tensorflow.keras_preprocessing.image import ImageDataGenerator
 | 
			
		||||
from keras_preprocessing.image import ImageDataGenerator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Classifier(object):
 | 
			
		||||
 
 | 
			
		||||
@@ -4,8 +4,8 @@ import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Config:
 | 
			
		||||
    PIKA_URL = os.environ['PIKA_URL']
 | 
			
		||||
    PIKA_OUTPUT_EXCHANGE = os.environ['PIKA_OUTPUT_EXCHANGE']
 | 
			
		||||
    PIKA_URL = os.environ.get('PIKA_URL')
 | 
			
		||||
    PIKA_OUTPUT_EXCHANGE = os.environ.get('PIKA_OUTPUT_EXCHANGE')
 | 
			
		||||
    PIKA_INPUT_EXCHANGE = os.environ.get('PIKA_INPUT_EXCHANGE', 'sample-ready')
 | 
			
		||||
 | 
			
		||||
    MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default")
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,7 @@ import sentry_sdk
 | 
			
		||||
from config import Config
 | 
			
		||||
 | 
			
		||||
from magic_doer import MagicDoer
 | 
			
		||||
from classifier_cache import ClassifierCache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_callback(channel, method, properties, body):
 | 
			
		||||
@@ -58,15 +59,10 @@ def message_callback(channel, method, properties, body):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    # setup logging
 | 
			
		||||
 | 
			
		||||
    logging.basicConfig(
 | 
			
		||||
        stream=sys.stdout,
 | 
			
		||||
        format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
 | 
			
		||||
        level=Config.LOG_LEVEL
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # setup observability stuffs
 | 
			
		||||
    if (not Config.PIKA_URL) or (not Config.PIKA_OUTPUT_EXCHANGE):
 | 
			
		||||
        logging.error("Mandatory config parameters unset: PIKA_URL or PIKA_OUTPUT_EXCHANGE")
 | 
			
		||||
        raise KeyError
 | 
			
		||||
 | 
			
		||||
    if Config.SENTRY_DSN:
 | 
			
		||||
        sentry_logging = LoggingIntegration(
 | 
			
		||||
@@ -112,5 +108,24 @@ def main():
 | 
			
		||||
    opentracing.tracer.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_loader():
 | 
			
		||||
    logging.info("Testing if model loading works...")
 | 
			
		||||
    cc = ClassifierCache(Config.MODEL_INFO_URL)
 | 
			
		||||
    details, classifier = cc.get_default_classifier()
 | 
			
		||||
    logging.info(f"Loaded classifier: {classifier}")
 | 
			
		||||
    logging.info(f"Details: {details}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
    # setup logging
 | 
			
		||||
 | 
			
		||||
    logging.basicConfig(
 | 
			
		||||
        stream=sys.stdout,
 | 
			
		||||
        format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
 | 
			
		||||
        level=Config.LOG_LEVEL
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if '--test-loader' in sys.argv:
 | 
			
		||||
        test_loader()
 | 
			
		||||
    else:
 | 
			
		||||
        main()
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ sentry_sdk
 | 
			
		||||
pika
 | 
			
		||||
requests
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
h5py < 3.0.0
 | 
			
		||||
librosa
 | 
			
		||||
keras
 | 
			
		||||
numpy
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user