diff --git a/model_service/schemas/aimodel_schema.py b/model_service/schemas/aimodel_schema.py index a6c62e7..e675bd9 100644 --- a/model_service/schemas/aimodel_schema.py +++ b/model_service/schemas/aimodel_schema.py @@ -1,13 +1,16 @@ #!/usr/bin/env python3 from marshmallow import fields from marshmallow_sqlalchemy import ModelSchema -from model import AIModel +from marshmallow_enum import EnumField +from model import AIModel, AIModelType class AIModelSchema(ModelSchema): default = fields.Method("boolize_default", dump_only=True) + type = EnumField(AIModelType) + def boolize_default(self, ai_model) -> bool: return bool(ai_model.default) diff --git a/model_service/views/cnn_view.py b/model_service/views/cnn_view.py index 47f61c1..cfece6b 100644 --- a/model_service/views/cnn_view.py +++ b/model_service/views/cnn_view.py @@ -60,8 +60,6 @@ class CNNView(FlaskView): storage.connection.put_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "weights/" + str(m.id), weights_file, weights_file.content_length, content_type=weights_file.content_type) - m = AIModel(id=info['id'], type=AIModelType.CNN) - db.session.add(m) db.session.commit()