This commit is contained in:
parent
d712b6e3db
commit
b37b6e2477
@ -12,7 +12,7 @@ import numpy
|
|||||||
import matplotlib.pyplot
|
import matplotlib.pyplot
|
||||||
from tensorflow.keras.models import model_from_json
|
from tensorflow.keras.models import model_from_json
|
||||||
from tensorflow.keras import optimizers
|
from tensorflow.keras import optimizers
|
||||||
from tensorflow.keras_preprocessing.image import ImageDataGenerator
|
from keras_preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
|
||||||
class Classifier(object):
|
class Classifier(object):
|
||||||
|
@ -4,8 +4,8 @@ import logging
|
|||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
PIKA_URL = os.environ['PIKA_URL']
|
PIKA_URL = os.environ.get('PIKA_URL')
|
||||||
PIKA_OUTPUT_EXCHANGE = os.environ['PIKA_OUTPUT_EXCHANGE']
|
PIKA_OUTPUT_EXCHANGE = os.environ.get('PIKA_OUTPUT_EXCHANGE')
|
||||||
PIKA_INPUT_EXCHANGE = os.environ.get('PIKA_INPUT_EXCHANGE', 'sample-ready')
|
PIKA_INPUT_EXCHANGE = os.environ.get('PIKA_INPUT_EXCHANGE', 'sample-ready')
|
||||||
|
|
||||||
MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default")
|
MODEL_INFO_URL = os.environ.get("MODEL_INFO_URL", "http://model-service/model/cnn/$default")
|
||||||
|
@ -17,6 +17,7 @@ import sentry_sdk
|
|||||||
from config import Config
|
from config import Config
|
||||||
|
|
||||||
from magic_doer import MagicDoer
|
from magic_doer import MagicDoer
|
||||||
|
from classifier_cache import ClassifierCache
|
||||||
|
|
||||||
|
|
||||||
def message_callback(channel, method, properties, body):
|
def message_callback(channel, method, properties, body):
|
||||||
@ -58,15 +59,10 @@ def message_callback(channel, method, properties, body):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# setup logging
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
stream=sys.stdout,
|
|
||||||
format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
|
|
||||||
level=Config.LOG_LEVEL
|
|
||||||
)
|
|
||||||
|
|
||||||
# setup observability stuffs
|
# setup observability stuffs
|
||||||
|
if (not Config.PIKA_URL) or (not Config.PIKA_OUTPUT_EXCHANGE):
|
||||||
|
logging.error("Mandatory config parameters unset: PIKA_URL or PIKA_OUTPUT_EXCHANGE")
|
||||||
|
raise KeyError
|
||||||
|
|
||||||
if Config.SENTRY_DSN:
|
if Config.SENTRY_DSN:
|
||||||
sentry_logging = LoggingIntegration(
|
sentry_logging = LoggingIntegration(
|
||||||
@ -112,5 +108,24 @@ def main():
|
|||||||
opentracing.tracer.close()
|
opentracing.tracer.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_loader():
|
||||||
|
logging.info("Testing if model loading works...")
|
||||||
|
cc = ClassifierCache(Config.MODEL_INFO_URL)
|
||||||
|
details, classifier = cc.get_default_classifier()
|
||||||
|
logging.info(f"Loaded classifier: {classifier}")
|
||||||
|
logging.info(f"Details: {details}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
# setup logging
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
stream=sys.stdout,
|
||||||
|
format="%(asctime)s - %(name)s [%(levelname)s]: %(message)s",
|
||||||
|
level=Config.LOG_LEVEL
|
||||||
|
)
|
||||||
|
|
||||||
|
if '--test-loader' in sys.argv:
|
||||||
|
test_loader()
|
||||||
|
else:
|
||||||
|
main()
|
||||||
|
@ -2,7 +2,7 @@ sentry_sdk
|
|||||||
pika
|
pika
|
||||||
requests
|
requests
|
||||||
|
|
||||||
|
h5py < 3.0.0
|
||||||
librosa
|
librosa
|
||||||
keras
|
keras
|
||||||
numpy
|
numpy
|
||||||
|
Loading…
Reference in New Issue
Block a user