Added ability to store CNN models
continuous-integration/drone/push Build is passing Details

This commit is contained in:
Pünkösd Marcell 2020-07-28 20:19:52 +02:00
parent fe965706b5
commit 284dfce735
12 changed files with 224 additions and 67 deletions

View File

@ -11,7 +11,7 @@ from model import db
from utils import register_all_error_handlers, storage
# import views
from views import ModelView
from views import SVMView, CNNView
# Setup sentry
SENTRY_DSN = os.environ.get("SENTRY_DSN")
@ -32,8 +32,8 @@ app.config['SQLALCHEMY_DATABASE_URI'] = os.environ.get('DATABASE_URI', "sqlite:/
app.config['MINIO_ENDPOINT'] = os.environ['MINIO_ENDPOINT']
app.config['MINIO_ACCESS_KEY'] = os.environ['MINIO_ACCESS_KEY']
app.config['MINIO_SECRET_KEY'] = os.environ['MINIO_SECRET_KEY']
app.config['MINIO_MODEL_BUCKET_NAME'] = os.environ['MINIO_MODEL_BUCKET_NAME']
app.config['MINIO_MEANS_BUCKET_NAME'] = os.environ['MINIO_MEANS_BUCKET_NAME']
app.config['MINIO_SVM_BUCKET_NAME'] = os.environ['MINIO_SVM_BUCKET_NAME']
app.config['MINIO_CNN_BUCKET_NAME'] = os.environ['MINIO_CNN_BUCKET_NAME']
app.config['MINIO_SECURE'] = os.environ.get('MINIO_SECURE', False)
app.config['MINIO_REGION'] = os.environ.get('MINIO_REGION', None)
@ -53,8 +53,8 @@ with app.app_context():
register_all_error_handlers(app)
# register views
for view in [ModelView]:
view.register(app, trailing_slash=False)
for view in [SVMView, CNNView]:
view.register(app, trailing_slash=False, route_prefix='/model')
# start debuggig if needed
if __name__ == "__main__":

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
from .db import db
from .default import Default
from .aimodel import AIModel
from .aimodel import AIModel, AIModelType
from .svmdetails import SVMDetails

View File

@ -3,16 +3,14 @@ from .db import db
from sqlalchemy.sql import func
from sqlalchemy.dialects.postgresql import UUID
import uuid
import enum
class AIModelType(enum.Enum):
SVM = 1
CNN = 2
class AIModel(db.Model):
id = db.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
timestamp = db.Column(db.TIMESTAMP, nullable=False, server_default=func.now())
# details
mid_window = db.Column(db.Float)
mid_step = db.Column(db.Float)
short_window = db.Column(db.Float)
short_step = db.Column(db.Float)
compute_beat = db.Column(db.Boolean)
type = db.Column(db.String(15))
type = db.Column(db.Enum(AIModelType), nullable=False)

View File

@ -2,7 +2,11 @@
from .db import db
from sqlalchemy.dialects.postgresql import UUID
from .aimodel import AIModelType
class Default(db.Model):
aimodel_id = db.Column(UUID(as_uuid=True), db.ForeignKey("ai_model.id"), nullable=False, primary_key=True)
type = db.Column(db.Enum(AIModelType), unique=True, nullable=False, primary_key=True)
aimodel_id = db.Column(UUID(as_uuid=True), db.ForeignKey("ai_model.id"), nullable=False)
aimodel = db.relationship("AIModel", backref=db.backref("default", lazy=True, cascade="save-update, merge, delete, delete-orphan"))

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python3
from .db import db
from sqlalchemy.dialects.postgresql import UUID
class SVMDetails(db.Model):
aimodel_id = db.Column(UUID(as_uuid=True), db.ForeignKey("ai_model.id"), nullable=False, primary_key=True)
aimodel = db.relationship("AIModel", backref=db.backref("details", lazy=True, cascade="save-update, merge, delete, delete-orphan"))
# details
mid_window = db.Column(db.Float)
mid_step = db.Column(db.Float)
short_window = db.Column(db.Float)
short_step = db.Column(db.Float)
compute_beat = db.Column(db.Boolean)

View File

@ -8,7 +8,6 @@ import uuid
class InfoSchema(Schema):
id = fields.UUID(default=uuid.uuid4, missing=uuid.uuid4)
type = fields.String(validate=OneOf(['svm', 'svm_rbf', 'knn', 'extratrees', 'gradientboosting', 'randomforest']))
@classmethod # This threats none values as missing
def get_attribute(cls, attr, obj, default):

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python3
from .require_decorators import json_required
from .error_handlers import register_all_error_handlers
from .storage import storage
from .storage import storage, ensure_buckets

View File

@ -1,3 +1,18 @@
from flask import current_app
from flask_minio import Minio
from minio.error import BucketAlreadyExists, BucketAlreadyOwnedByYou
storage = Minio()
storage = Minio()
def ensure_buckets():
for bucket_name in [current_app.config['MINIO_SVM_BUCKET_NAME'],
current_app.config['MINIO_CNN_BUCKET_NAME']]:
try:
storage.connection.make_bucket(bucket_name)
except BucketAlreadyOwnedByYou:
pass
except BucketAlreadyExists:
pass
# Everything else should be raised

View File

@ -1,2 +1,3 @@
#!/usr/bin/env python3
from .model_view import ModelView
from .svm_view import SVMView
from .cnn_view import CNNView

View File

@ -0,0 +1,128 @@
#!/usr/bin/env python3
from flask import request, jsonify, current_app, abort, Response
from flask_classful import FlaskView, route
from model import db, Default, AIModel, AIModelType
from minio.error import NoSuchKey
from schemas import AIModelSchema, DefaultSchema, InfoSchema
from marshmallow.exceptions import ValidationError
from utils import json_required, storage, ensure_buckets
class CNNView(FlaskView):
aimodel_schema = AIModelSchema(many=False)
aimodels_schema = AIModelSchema(many=True, exclude=['timestamp', 'details'])
default_schema = DefaultSchema(many=False)
info_schema = InfoSchema(many=False)
def index(self):
models = AIModel.query.filter_by(type=AIModelType.CNN).all()
return jsonify(self.aimodels_schema.dump(models)), 200
def post(self):
# get important data from the request
try:
info = self.info_schema.loads(request.form.get('info'))
except ValidationError as e:
abort(400, str(e))
# check for conflict
m = AIModel.query.filter_by(id=info['id']).first()
if m:
abort(409)
# get and validate file
model_file = request.files['modelFile']
if model_file.content_length <= 0:
abort(411, f"Content length for modelFile is not a positive integer or missing.")
weights_file = request.files['weightsFile']
if weights_file.content_length <= 0:
abort(411, f"Content length for weightsFile is not a positive integer or missing.")
# create bucket if necessary
ensure_buckets()
# Put files into MinIO
storage.connection.put_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "model/" + str(m.id), model_file,
model_file.content_length, content_type=model_file.content_type)
storage.connection.put_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "weights/" + str(m.id), weights_file,
weights_file.content_length, content_type=weights_file.content_type)
m = AIModel(id=info['id'], type=AIModelType.CNN)
db.session.add(m)
db.session.commit()
return jsonify(self.aimodel_schema.dump(m)), 200
def get(self, _id: str):
if _id == "$default":
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.CNN).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(type=AIModelType.CNN, id=_id).first_or_404()
if "weights" in request.args:
path = "weights/" + str(m.id)
else:
path = "model/" + str(m.id)
try:
data = storage.connection.get_object(current_app.config['MINIO_CNN_BUCKET_NAME'], path)
except NoSuchKey:
abort(500, "The ID is stored in the database but not int the Object Store")
return Response(data.stream(), mimetype=data.headers['Content-type'])
@route('<_id>/details')
def get_details(self, _id: str):
if _id == "$default":
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.CNN).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(type=AIModelType.CNN, id=_id).first_or_404()
return jsonify(self.aimodel_schema.dump(m))
def delete(self, _id: str):
if _id == "$default":
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.CNN).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(type=AIModelType.CNN, id=_id).first_or_404()
storage.connection.remove_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "weights/" + str(m.id))
storage.connection.remove_object(current_app.config['MINIO_CNN_BUCKET_NAME'], "model/" + str(m.id))
db.session.delete(m)
db.session.commit()
return '', 204
@json_required
@route('$default', methods=['PUT'])
def put_default(self):
try:
req = self.default_schema.load(request.json)
except ValidationError as e:
abort(400, str(e))
m = AIModel.query.filter_by(type=AIModelType.CNN, id=req['id']).first_or_404()
Default.query.filter_by(type=AIModelType.CNN).delete()
new_default = Default(type=AIModelType.CNN, aimodel=m)
db.session.add(new_default)
db.session.commit()
return '', 204

View File

@ -3,36 +3,22 @@ import tempfile
import os
from flask import request, jsonify, current_app, abort, Response
from flask_classful import FlaskView, route
from model import db, Default, AIModel
from minio.error import BucketAlreadyExists, BucketAlreadyOwnedByYou, ResponseError, NoSuchKey
from model import db, Default, AIModel, AIModelType, SVMDetails
from minio.error import NoSuchKey
from schemas import AIModelSchema, DefaultSchema, InfoSchema
from marshmallow.exceptions import ValidationError
from utils import json_required, storage
from utils import json_required, storage, ensure_buckets
from pyAudioAnalysis.audioTrainTest import load_model, load_model_knn
class ModelView(FlaskView):
class SVMView(FlaskView):
aimodel_schema = AIModelSchema(many=False)
aimodels_schema = AIModelSchema(many=True,
exclude=['timestamp', 'mid_window', 'mid_step', 'short_window', 'short_step',
'compute_beat'])
aimodels_schema = AIModelSchema(many=True, exclude=['timestamp', 'details'])
default_schema = DefaultSchema(many=False)
info_schema = InfoSchema(many=False)
def _ensure_buckets(self):
for bucket_name in [current_app.config['MINIO_MEANS_BUCKET_NAME'],
current_app.config['MINIO_MODEL_BUCKET_NAME']]:
try:
storage.connection.make_bucket(bucket_name)
except BucketAlreadyOwnedByYou as err:
pass
except BucketAlreadyExists as err:
pass
# Everything else should be raised
def index(self):
models = AIModel.query.all()
models = AIModel.query.filter_by(type=AIModelType.SVM).all()
return jsonify(self.aimodels_schema.dump(models)), 200
def post(self):
@ -60,7 +46,7 @@ class ModelView(FlaskView):
abort(411, f"Content length for meansFile is not a positive integer or missing.")
# create bucket if necessary
self._ensure_buckets()
ensure_buckets()
# Temporarily save the file, because pyAudioAnalysis can only read files
_, temp_model_filename = tempfile.mkstemp()
@ -71,23 +57,19 @@ class ModelView(FlaskView):
try:
if info['type'] == 'knn':
_, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \
= load_model_knn(temp_model_filename)
else:
_, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \
= load_model(temp_model_filename)
_, _, _, _, mid_window, mid_step, short_window, short_step, compute_beat \
= load_model(temp_model_filename)
# Because of pyAudiomeme the files already saved, so we just use the file uploader functions
storage.connection.fput_object(
current_app.config['MINIO_MODEL_BUCKET_NAME'],
str(info['id']),
current_app.config['MINIO_SVM_BUCKET_NAME'],
"model/" + str(info['id']),
temp_model_filename
)
storage.connection.fput_object(
current_app.config['MINIO_MEANS_BUCKET_NAME'],
str(info['id']),
current_app.config['MINIO_SVM_BUCKET_NAME'],
"means/" + str(info['id']),
temp_means_filename
)
@ -95,10 +77,19 @@ class ModelView(FlaskView):
os.remove(temp_model_filename)
os.remove(temp_means_filename)
m = AIModel(id=info['id'], mid_window=mid_window, mid_step=mid_step, short_window=short_window,
short_step=short_step, compute_beat=compute_beat, type=info['type'])
m = AIModel(id=info['id'], type=AIModelType.SVM)
d = SVMDetails(
aimodel=m,
mid_window=mid_window,
mid_step=mid_step,
short_window=short_window,
short_step=short_step,
compute_beat=compute_beat
)
db.session.add(m)
db.session.add(d)
db.session.commit()
return jsonify(self.aimodel_schema.dump(m)), 200
@ -106,18 +97,19 @@ class ModelView(FlaskView):
def get(self, _id: str):
if _id == "$default":
default = Default.query.first_or_404() # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(id=_id).first_or_404()
m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
if "means" in request.args:
bucket = current_app.config['MINIO_MEANS_BUCKET_NAME']
path = "means/" + str(m.id)
else:
bucket = current_app.config['MINIO_MODEL_BUCKET_NAME']
path = "model/" + str(m.id)
try:
data = storage.connection.get_object(bucket, str(m.id))
data = storage.connection.get_object(current_app.config['MINIO_SVM_BUCKET_NAME'], path)
except NoSuchKey:
abort(500, "The ID is stored in the database but not int the Object Store")
@ -127,23 +119,25 @@ class ModelView(FlaskView):
def get_details(self, _id: str):
if _id == "$default":
default = Default.query.first_or_404() # TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(id=_id).first_or_404()
m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
return jsonify(self.aimodel_schema.dump(m))
def delete(self, _id: str):
if _id == '$default':
default = Default.query.first_or_404()
if _id == "$default":
# TODO: Kitalálni, hogy inkább a latestestest-el térjen-e vissza
default = Default.query.filter_by(type=AIModelType.SVM).first_or_404()
m = default.aimodel
else:
m = AIModel.query.filter_by(id=_id).first_or_404()
m = AIModel.query.filter_by(type=AIModelType.SVM, id=_id).first_or_404()
storage.connection.remove_object(current_app.config['MINIO_MODEL_BUCKET_NAME'], str(m.id))
storage.connection.remove_object(current_app.config['MINIO_MEANS_BUCKET_NAME'], str(m.id))
storage.connection.remove_object(current_app.config['MINIO_SVM_BUCKET_NAME'], "means/" + str(m.id))
storage.connection.remove_object(current_app.config['MINIO_SVM_BUCKET_NAME'], "model/" + str(m.id))
db.session.delete(m)
db.session.commit()
@ -157,12 +151,12 @@ class ModelView(FlaskView):
try:
req = self.default_schema.load(request.json)
except ValidationError as e:
abort(404, str(e))
abort(400, str(e))
m = AIModel.query.filter_by(id=req['id']).first_or_404()
m = AIModel.query.filter_by(type=AIModelType.SVM, id=req['id']).first_or_404()
Default.query.delete()
new_default = Default(aimodel=m)
Default.query.filter_by(type=AIModelType.SVM).delete()
new_default = Default(type=AIModelType.SVM, aimodel=m)
db.session.add(new_default)
db.session.commit()

View File

@ -7,6 +7,7 @@ Flask-SQLAlchemy
SQLAlchemy-Utils
SQLAlchemy
marshmallow-sqlalchemy
marshmallow-enum
psycopg2-binary
flask_minio
sentry-sdk