diff --git a/model_service/model/aimodel.py b/model_service/model/aimodel.py index 6e86115..e6d60b8 100644 --- a/model_service/model/aimodel.py +++ b/model_service/model/aimodel.py @@ -5,12 +5,16 @@ from sqlalchemy.dialects.postgresql import UUID import uuid import enum + class AIModelType(enum.Enum): SVM = 1 CNN = 2 + class AIModel(db.Model): id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False) timestamp = db.Column(db.TIMESTAMP, nullable=False, server_default=func.now()) type = db.Column(db.Enum(AIModelType), nullable=False) + + target_class_name = db.Column(db.String(50), nullable=False) diff --git a/model_service/schemas/info_schema.py b/model_service/schemas/info_schema.py index 8a515b8..c131c20 100644 --- a/model_service/schemas/info_schema.py +++ b/model_service/schemas/info_schema.py @@ -8,6 +8,7 @@ import uuid class InfoSchema(Schema): id = fields.UUID(default=uuid.uuid4, missing=uuid.uuid4) + target_class_name = fields.String() @classmethod # This threats none values as missing def get_attribute(cls, attr, obj, default): diff --git a/model_service/views/cnn_view.py b/model_service/views/cnn_view.py index cfece6b..67d4e82 100644 --- a/model_service/views/cnn_view.py +++ b/model_service/views/cnn_view.py @@ -48,10 +48,7 @@ class CNNView(FlaskView): ensure_buckets() # Create the entry in the db - if info['id']: - m = AIModel(id=info['id'], type=AIModelType.CNN) - else: - m = AIModel(type=AIModelType.CNN) + m = AIModel(id=info['id'], type=AIModelType.CNN, target_class_name=info['target_class_name']) # Put files into MinIO storage.connection.put_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "model/" + str(m.id), model_file, diff --git a/model_service/views/svm_view.py b/model_service/views/svm_view.py index 81456d4..59433dc 100644 --- a/model_service/views/svm_view.py +++ b/model_service/views/svm_view.py @@ -81,7 +81,7 @@ class SVMView(FlaskView): os.remove(temp_model_filename) os.remove(temp_means_filename) - m = AIModel(id=info['id'], type=AIModelType.SVM) + m = AIModel(id=info['id'], type=AIModelType.SVM, target_class_name=info['target_class_name']) d = SVMDetails( aimodel=m,