Added ability to store CNN models
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Pünkösd Marcell 2020-07-28 20:19:52 +02:00
parent fe965706b5
commit 284dfce735
12 changed files with 224 additions and 67 deletions

View File

@ -11,7 +11,7 @@ from model import db
from utils import register_all_error_handlers, storage from utils import register_all_error_handlers, storage
# import views # import views
from views import ModelView from views import SVMView, CNNView
# Setup sentry # Setup sentry
SENTRY_DSN = os.environ.get("SENTRY_DSN") SENTRY_DSN = os.environ.get("SENTRY_DSN")
@ -32,8 +32,8 @@ app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get('DATABASE_URI', "sqlite:/
app.config['MINIO_ENDPOINT'] = os.environ['MINIO_ENDPOINT'] app.config['MINIO_ENDPOINT'] = os.environ['MINIO_ENDPOINT']
app.config['MINIO_ACCESS_KEY'] = os.environ['MINIO_ACCESS_KEY'] app.config['MINIO_ACCESS_KEY'] = os.environ['MINIO_ACCESS_KEY']
app.config['MINIO_SECRET_KEY'] = os.environ['MINIO_SECRET_KEY'] app.config['MINIO_SECRET_KEY'] = os.environ['MINIO_SECRET_KEY']
app.config['MINIO_MODEL_BUCKET_NAME'] = os.environ['MINIO_MODEL_BUCKET_NAME'] app.config['MINIO_SVM_BUCKET_NAME'] = os.environ['MINIO_SVM_BUCKET_NAME']
app.config['MINIO_MEANS_BUCKET_NAME'] = os.environ['MINIO_MEANS_BUCKET_NAME'] app.config['MINIO_CNN_BUCKET_NAME'] = os.environ['MINIO_CNN_BUCKET_NAME']
app.config['MINIO_SECURE'] = os.environ.get('MINIO_SECURE', False) app.config['MINIO_SECURE'] = os.environ.get('MINIO_SECURE', False)
app.config['MINIO_REGION'] = os.environ.get('MINIO_REGION', None) app.config['MINIO_REGION'] = os.environ.get('MINIO_REGION', None)
@ -53,8 +53,8 @@ with app.app_context():
register_all_error_handlers(app) register_all_error_handlers(app)
# register views # register views
for view in [ModelView]: for view in [SVMView, CNNView]:
view.register(app, trailing_slash=False) view.register(app, trailing_slash=False, route_prefix='/model')
# start debugging if needed # start debugging if needed
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from .db import db from .db import db
from .default import Default from .default import Default
from .aimodel import AIModel from .aimodel import AIModel, AIModelType
from .svmdetails import SVMDetails

View File

@ -3,16 +3,14 @@ from .db import db
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
import uuid import uuid
import enum
class AIModelType(enum.Enum):
    """Discriminator for the kind of model a stored record represents."""

    # The numeric values are part of the enum's identity; keep them stable.
    SVM = 1
    CNN = 2
class AIModel(db.Model): class AIModel(db.Model):
id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False) id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
timestamp = db.Column(db.TIMESTAMP, nullable=False, server_default=func.now()) timestamp = db.Column(db.TIMESTAMP, nullable=False, server_default=func.now())
# details type = db.Column(db.Enum(AIModelType), nullable=False)
mid_window = db.Column(db.Float)
mid_step = db.Column(db.Float)
short_window = db.Column(db.Float)
short_step = db.Column(db.Float)
compute_beat = db.Column(db.Boolean)
type = db.Column(db.String(15))

View File

@ -2,7 +2,11 @@
from .db import db from .db import db
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from .aimodel import AIModelType
class Default(db.Model): class Default(db.Model):
aimodel_id = db.Column(UUID(as_uuid=True), db.ForeignKey("ai_model.id"), nullable=False, primary_key=True) type = db.Column(db.Enum(AIModelType), unique=True, nullable=False, primary_key=True)
aimodel_id = db.Column(UUID(as_uuid=True), db.ForeignKey("ai_model.id"), nullable=False)
aimodel = db.relationship("AIModel", backref=db.backref("default", lazy=True, cascade="save-update, merge, delete, delete-orphan")) aimodel = db.relationship("AIModel", backref=db.backref("default", lazy=True, cascade="save-update, merge, delete, delete-orphan"))

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python3
from .db import db
from sqlalchemy.dialects.postgresql import UUID
class SVMDetails(db.Model):
    """Windowing parameters attached to a single ``AIModel`` row.

    The primary key doubles as a foreign key to ``ai_model.id``.  The owning
    model reaches these rows through the ``details`` backref, which is
    configured so that they are deleted together with the model.
    """

    aimodel_id = db.Column(
        UUID(as_uuid=True),
        db.ForeignKey("ai_model.id"),
        nullable=False,
        primary_key=True,
    )
    aimodel = db.relationship(
        "AIModel",
        backref=db.backref(
            "details",
            lazy=True,
            cascade="save-update, merge, delete, delete-orphan",
        ),
    )

    # Detail values; all of them are nullable.
    mid_window = db.Column(db.Float)
    mid_step = db.Column(db.Float)
    short_window = db.Column(db.Float)
    short_step = db.Column(db.Float)
    compute_beat = db.Column(db.Boolean)

View File

@ -8,7 +8,6 @@ import uuid
class InfoSchema(Schema): class InfoSchema(Schema):
id = fields.UUID(default=uuid.uuid4, missing=uuid.uuid4) id = fields.UUID(default=uuid.uuid4, missing=uuid.uuid4)
type = fields.String(validate=OneOf(['svm', 'svm_rbf', 'knn', 'extratrees', 'gradientboosting', 'randomforest']))
@classmethod # This threats none values as missing @classmethod # This threats none values as missing
def get_attribute(cls, attr, obj, default): def get_attribute(cls, attr, obj, default):

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from .require_decorators import json_required from .require_decorators import json_required
from .error_handlers import register_all_error_handlers from .error_handlers import register_all_error_handlers
from .storage import storage from .storage import storage, ensure_buckets

View File

@ -1,3 +1,18 @@
from flask import current_app
from flask_minio import Minio from flask_minio import Minio
from minio.error import BucketAlreadyExists, BucketAlreadyOwnedByYou
storage = Minio() storage = Minio()
def ensure_buckets():
    """Make sure the SVM and CNN buckets exist in the object store.

    Both bucket names are read from the current application's config before
    any bucket is created.  A bucket that already exists is not an error;
    every other exception raised by the storage client is propagated.
    """
    config = current_app.config
    wanted = (config['MINIO_SVM_BUCKET_NAME'], config['MINIO_CNN_BUCKET_NAME'])

    for name in wanted:
        try:
            storage.connection.make_bucket(name)
        except (BucketAlreadyOwnedByYou, BucketAlreadyExists):
            # The bucket is already there, which is exactly what we want.
            continue

View File

@ -1,2 +1,3 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from .model_view import ModelView from .svm_view import SVMView
from .cnn_view import CNNView

View File

@ -0,0 +1,128 @@
#!/usr/bin/env python3
from flask import request, jsonify, current_app, abort, Response
from flask_classful import FlaskView, route
from model import db, Default, AIModel, AIModelType
from minio.error import NoSuchKey
from schemas import AIModelSchema, DefaultSchema, InfoSchema
from marshmallow.exceptions import ValidationError
from utils import json_required, storage, ensure_buckets
class CNNView(FlaskView):
    """REST endpoints for uploading, fetching and deleting CNN models.

    Every model is an ``AIModel`` row of type ``AIModelType.CNN`` plus two
    objects in the CNN bucket: ``model/<id>`` and ``weights/<id>``.
    """

    aimodel_schema = AIModelSchema(many=False)
    aimodels_schema = AIModelSchema(many=True, exclude=['timestamp', 'details'])
    default_schema = DefaultSchema(many=False)
    info_schema = InfoSchema(many=False)

    def _find_model(self, _id: str):
        """Return the CNN model for ``_id``; ``"$default"`` means the default one.

        Aborts with 404 when no matching default or model exists.  The leading
        underscore keeps this helper from being exposed as a route.
        """
        if _id == "$default":
            # TODO: decide whether this should return the latest model instead
            default = Default.query.filter_by(type=AIModelType.CNN).first_or_404()
            return default.aimodel

        return AIModel.query.filter_by(type=AIModelType.CNN, id=_id).first_or_404()

    def index(self):
        """List all stored CNN models (without timestamp and details)."""
        models = AIModel.query.filter_by(type=AIModelType.CNN).all()
        return jsonify(self.aimodels_schema.dump(models)), 200

    def post(self):
        """Store a new CNN model.

        Expects a form field ``info`` (JSON) and the files ``modelFile`` and
        ``weightsFile``.  Aborts with 400 on a missing or invalid ``info``,
        409 when the id is already taken and 411 when a file has no positive
        content length.
        """
        # get important data from the request
        raw_info = request.form.get('info')
        if raw_info is None:
            # Without this guard the schema would be handed None and the
            # resulting error would surface as a 500.
            abort(400, "Missing form field: info")

        try:
            info = self.info_schema.loads(raw_info)
        except (ValidationError, ValueError) as e:
            # ValueError covers malformed JSON, which is not a ValidationError.
            abort(400, str(e))

        # check for conflict
        if AIModel.query.filter_by(id=info['id']).first():
            abort(409)

        # get and validate files
        model_file = request.files['modelFile']
        if model_file.content_length <= 0:
            abort(411, "Content length for modelFile is not a positive integer or missing.")

        weights_file = request.files['weightsFile']
        if weights_file.content_length <= 0:
            abort(411, "Content length for weightsFile is not a positive integer or missing.")

        # create bucket if necessary
        ensure_buckets()

        # Put files into MinIO.  The object names are built from the requested
        # id: no model row exists yet at this point, so there is no model
        # object whose id could be read.
        bucket = current_app.config['MINIO_CNN_BUCKET_NAME']
        object_id = str(info['id'])
        storage.connection.put_object(bucket, "model/" + object_id, model_file,
                                      model_file.content_length, content_type=model_file.content_type)
        storage.connection.put_object(bucket, "weights/" + object_id, weights_file,
                                      weights_file.content_length, content_type=weights_file.content_type)

        m = AIModel(id=info['id'], type=AIModelType.CNN)

        db.session.add(m)
        db.session.commit()

        return jsonify(self.aimodel_schema.dump(m)), 200

    def get(self, _id: str):
        """Stream the model file, or the weights file if ``weights`` is in the query string."""
        m = self._find_model(_id)

        if "weights" in request.args:
            path = "weights/" + str(m.id)
        else:
            path = "model/" + str(m.id)

        try:
            data = storage.connection.get_object(current_app.config['MINIO_CNN_BUCKET_NAME'], path)
        except NoSuchKey:
            abort(500, "The ID is stored in the database but not int the Object Store")

        return Response(data.stream(), mimetype=data.headers['Content-type'])

    @route('<_id>/details')
    def get_details(self, _id: str):
        """Return the database record of the model as JSON."""
        m = self._find_model(_id)
        return jsonify(self.aimodel_schema.dump(m))

    def delete(self, _id: str):
        """Remove both stored objects and the database record of the model."""
        m = self._find_model(_id)

        bucket = current_app.config['MINIO_CNN_BUCKET_NAME']
        storage.connection.remove_object(bucket, "weights/" + str(m.id))
        storage.connection.remove_object(bucket, "model/" + str(m.id))

        db.session.delete(m)
        db.session.commit()

        return '', 204

    @json_required
    @route('$default', methods=['PUT'])
    def put_default(self):
        """Make the CNN model named in the JSON body the default CNN model."""
        try:
            req = self.default_schema.load(request.json)
        except ValidationError as e:
            abort(400, str(e))

        m = AIModel.query.filter_by(type=AIModelType.CNN, id=req['id']).first_or_404()

        # Only one default per type: drop the previous CNN default first.
        Default.query.filter_by(type=AIModelType.CNN).delete()
        new_default = Default(type=AIModelType.CNN, aimodel=m)
        db.session.add(new_default)
        db.session.commit()

        return '', 204

View File

@ -3,36 +3,22 @@ import tempfile
import os import os
from flask import request, jsonify, current_app, abort, Response from flask import request, jsonify, current_app, abort, Response
from flask_classful import FlaskView, route from flask_classful import FlaskView, route
from model import db, Default, AIModel from model import db, Default, AIModel, AIModelType, SVMDetails
from minio.error import BucketAlreadyExists, BucketAlreadyOwnedByYou, ResponseError, NoSuchKey from minio.error import NoSuchKey
from schemas import AIModelSchema, DefaultSchema, InfoSchema from schemas import AIModelSchema, DefaultSchema, InfoSchema
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from utils import json_required, storage from utils import json_required, storage, ensure_buckets
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn
class ModelView(FlaskView): class SVMView(FlaskView):
aimodel_schema = AIModelSchema(many=False) aimodel_schema = AIModelSchema(many=False)
aimodels_schema = AIModelSchema(many=True, aimodels_schema = AIModelSchema(many=True, exclude=['timestamp', 'details'])
exclude=['timestamp', 'mid_window', 'mid_step', 'short_window', 'short_step',
'compute_beat'])
default_schema = DefaultSchema(many=False) default_schema = DefaultSchema(many=False)
info_schema = InfoSchema(many=False) info_schema = InfoSchema(many=False)
def _ensure_buckets(self):
for bucket_name in [current_app.config['MINIO_MEANS_BUCKET_NAME'],
current_app.config['MINIO_MODEL_BUCKET_NAME']]:
try:
storage.connection.make_bucket(bucket_name)
except BucketAlreadyOwnedByYou as err:
pass
except BucketAlreadyExists as err:
pass
# Everything else should be raised
def index(self): def index(self):
models = AIModel.query.all() models = AIModel.query.filter_by(type=AIModelType.SVM).all()
return jsonify(self.aimodels_schema.dump(models)), 200 return jsonify(self.aimodels_schema.dump(models)), 200
def post(self): def post(self):
@ -60,7 +46,7 @@ class ModelView(FlaskView):
abort(411, f"Content length for meansFile is not a positive integer or missing.") abort(411, f"Content length for meansFile is not a positive integer or missing.")
# create bucket if necessary # create bucket if necessary
self._ensure_buckets() ensure_buckets()
# Temporarily save the file, because pyAudioAnalysis can only read files # Temporarily save the file, because pyAudioAnalysis can only read files
_, temp_model_filename = tempfile.mkstemp() _, temp_model_filename = tempfile.mkstemp()
@ -71,23 +57,19 @@ class ModelView(FlaskView):
try: try:
if info['type'] == 'knn': _, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \
_, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \ = load_model(temp_model_filename)
= load_model_knn(temp_model_filename)
else:
_, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \
= load_model(temp_model_filename)
# pyAudioAnalysis needed the files saved to disk, so we just use the file-based upload functions # pyAudioAnalysis needed the files saved to disk, so we just use the file-based upload functions
storage.connection.fput_object( storage.connection.fput_object(
current_app.config['MINIO_MODEL_BUCKET_NAME'], current_app.config['MINIO_SVM_BUCKET_NAME'],
str(info['id']), "model/" + str(info['id']),
temp_model_filename temp_model_filename
) )
storage.connection.fput_object( storage.connection.fput_object(
current_app.config['MINIO_MEANS_BUCKET_NAME'], current_app.config['MINIO_SVM_BUCKET_NAME'],
str(info['id']), "means/" + str(info['id']),
temp_means_filename temp_means_filename
) )
@ -95,10 +77,19 @@ class ModelView(FlaskView):
os.remove(temp_model_filename) os.remove(temp_model_filename)
os.remove(temp_means_filename) os.remove(temp_means_filename)
m = AIModel(id=info['id'], mid_window=mid_window, mid_step=mid_step, short_window=short_window, m = AIModel(id=info['id'], type=AIModelType.SVM)
short_step=short_step, compute_beat=compute_beat, type=info['type'])
d = SVMDetails(
aimodel=m,
mid_window=mid_window,
mid_step=mid_step,
short_window=short_window,
short_step=short_step,
compute_beat=compute_beat
)
db.session.add(m) db.session.add(m)
db.session.add(d)
db.session.commit() db.session.commit()
return jsonify(self.aimodel_schema.dump(m)), 200 return jsonify(self.aimodel_schema.dump(m)), 200
@ -106,18 +97,19 @@ class ModelView(FlaskView):
def get(self, _id: str): def get(self, _id: str):
if _id == "$default": if _id == "$default":
default = Default.query.first_or_404() # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel m = default.aimodel
else: else:
m = AIModel.query.filter_by(id=_id).first_or_404() m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
if "means" in request.args: if "means" in request.args:
bucket = current_app.config['MINIO_MEANS_BUCKET_NAME'] path = "means/" + str(m.id)
else: else:
bucket = current_app.config['MINIO_MODEL_BUCKET_NAME'] path = "model/" + str(m.id)
try: try:
data = storage.connection.get_object(bucket, str(m.id)) data = storage.connection.get_object(current_app.config['MINIO_SVM_BUCKET_NAME'], path)
except NoSuchKey: except NoSuchKey:
abort(500, "The ID is stored in the database but not int the Object Store") abort(500, "The ID is stored in the database but not int the Object Store")
@ -127,23 +119,25 @@ class ModelView(FlaskView):
def get_details(self, _id: str): def get_details(self, _id: str):
if _id == "$default": if _id == "$default":
default = Default.query.first_or_404() # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel m = default.aimodel
else: else:
m = AIModel.query.filter_by(id=_id).first_or_404() m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
return jsonify(self.aimodel_schema.dump(m)) return jsonify(self.aimodel_schema.dump(m))
def delete(self, _id: str): def delete(self, _id: str):
if _id == '$default': if _id == "$default":
default = Default.query.first_or_404() # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel m = default.aimodel
else: else:
m = AIModel.query.filter_by(id=_id).first_or_404() m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
storage.connection.remove_object(current_app.config['MINIO_MODEL_BUCKET_NAME'], str(m.id)) storage.connection.remove_object(current_app.config['MINIO_SVM_BUCKET_NAME'], "means/" + str(m.id))
storage.connection.remove_object(current_app.config['MINIO_MEANS_BUCKET_NAME'], str(m.id)) storage.connection.remove_object(current_app.config['MINIO_SVM_BUCKET_NAME'], "model/" + str(m.id))
db.session.delete(m) db.session.delete(m)
db.session.commit() db.session.commit()
@ -157,12 +151,12 @@ class ModelView(FlaskView):
try: try:
req = self.default_schema.load(request.json) req = self.default_schema.load(request.json)
except ValidationError as e: except ValidationError as e:
abort(404, str(e)) abort(400, str(e))
m = AIModel.query.filter_by(id=req['id']).first_or_404() m = AIModel.query.filter_by(type=AIModelType.SVM, id=req['id']).first_or_404()
Default.query.delete() Default.query.filter_by(type=AIModelType.SVM).delete()
new_default = Default(aimodel=m) new_default = Default(type=AIModelType.SVM, aimodel=m)
db.session.add(new_default) db.session.add(new_default)
db.session.commit() db.session.commit()

View File

@ -7,6 +7,7 @@ Flask-SQLAlchemy
SQLAlchemy-Utils SQLAlchemy-Utils
SQLAlchemy SQLAlchemy
marshmallow-sqlalchemy marshmallow-sqlalchemy
marshmallow-enum
psycopg2-binary psycopg2-binary
flask_minio flask_minio
sentry-sdk sentry-sdk