model-service/model_service/views/root_view.py

114 lines
3.7 KiB
Python

#!/usr/bin/env python3
from flask import jsonify, abort, request, url_for
from flask_classful import FlaskView, route
from marshmallow import ValidationError
from utils import json_required
from model import db, AIModel, AIModelType, Default
from schemas import AIModelSchema, DefaultSchema
import opentracing
class RootView(FlaskView):
    """Type-agnostic endpoints mounted at the service root.

    Lists stored AI models (all of them, or those of one ``AIModelType``),
    returns the details of a single model -- adding download links for CNN
    and SVM models -- and replaces the default model of a type.
    """

    route_base = '/'
    # List responses leave out the timestamp, details and target_class_name fields.
    aimodels_schema = AIModelSchema(many=True, exclude=['timestamp', 'details', 'target_class_name'])
    aimodel_schema = AIModelSchema(many=False)
    default_schema = DefaultSchema(many=False)

    @staticmethod
    def _parse_type(type_: str) -> AIModelType:
        """Map a URL path segment to the ``AIModelType`` member of that name.

        Aborts the request with 404 "Unknown type" when ``type_`` is not a
        member name. The leading underscore keeps FlaskView from routing it.
        """
        try:
            return AIModelType[type_]
        except KeyError:
            abort(404, "Unknown type")  # abort() raises; nothing is returned

    def index(self):
        """List every stored model, regardless of type (the view's index route)."""
        with opentracing.tracer.start_active_span('sqlalchemy.select'):
            models = AIModel.query.all()
        return jsonify(self.aimodels_schema.dump(models))

    @route('/<type_>')
    def get_models(self, type_: str):
        """List the stored models of one type.

        :param type_: name of an ``AIModelType`` member; unknown names yield 404.
        :returns: JSON list of models serialized with ``aimodels_schema``, status 200.
        """
        aimodel_type = self._parse_type(type_)
        # OpenTracing tag values must be str/bool/numeric, so tag the enum's name.
        with opentracing.tracer.start_active_span(
            'sqlalchemy.select',
            tags={"aimodel_type": aimodel_type.name},
        ):
            models = AIModel.query.filter_by(type=aimodel_type).all()
        return jsonify(self.aimodels_schema.dump(models)), 200

    @route('/<type_>/<id_>')
    def get_model(self, type_: str, id_: str):
        """Return the details of one model, plus download links for its files.

        :param type_: name of an ``AIModelType`` member; unknown names yield 404.
        :param id_: the model id, or the literal ``$default`` to resolve the
            type's current default model. A missing model or default yields 404.
        :returns: JSON object from ``aimodel_schema``; CNN and SVM models also
            get a ``files`` object holding ``url_for`` links to the type's
            ``get_file`` endpoint.
        """
        aimodel_type = self._parse_type(type_)
        with opentracing.tracer.start_active_span(
            'sqlalchemy.select',
            tags={"aimodel_type": aimodel_type.name, "id": id_},
        ):
            if id_ == "$default":
                default = Default.query.filter_by(type=aimodel_type).first_or_404()
                m = default.aimodel
            else:
                m = AIModel.query.filter_by(type=aimodel_type, id=id_).first_or_404()

        with opentracing.tracer.start_active_span('compileResponseDict', tags={"id": id_}):
            details = self.aimodel_schema.dump(m)
            # The download links are built here, in the shared view, instead of
            # in each type-specific view; the alternative would duplicate this
            # whole handler in every one of them.
            if aimodel_type == AIModelType.cnn:
                details["files"] = {
                    "model": url_for("CNNView:get_file", id_=m.id),
                    "weights": url_for("CNNView:get_file", id_=m.id, weights=''),
                }
            elif aimodel_type == AIModelType.svm:
                details["files"] = {
                    "model": url_for("SVMView:get_file", id_=m.id),
                    "means": url_for("SVMView:get_file", id_=m.id, means=''),
                }
        return jsonify(details)

    # NOTE(review): @route stores its rule as an attribute on the function it
    # decorates; json_required sits outside it, so the rule is only visible to
    # FlaskView if json_required preserves function attributes (e.g. via
    # functools.wraps) -- confirm in utils.
    @json_required
    @route('/<type_>/$default', methods=['PUT'])
    def put_default(self, type_: str):
        """Make an existing model the default of its type.

        The JSON body is loaded with ``DefaultSchema`` and must carry the ``id``
        of a model of type ``type_``. Responds 404 for an unknown type or model,
        400 (with the validation messages) when the body fails validation, and
        an empty 204 on success.
        """
        aimodel_type = self._parse_type(type_)
        try:
            req = self.default_schema.load(request.json)
        except ValidationError as e:
            abort(400, str(e))

        with opentracing.tracer.start_active_span('sqlalchemy.select'):
            m = AIModel.query.filter_by(type=aimodel_type, id=req['id']).first_or_404()
        # Remove the previous default(s) of this type, then add the new one;
        # neither change is committed until the single commit below.
        with opentracing.tracer.start_active_span('sqlalchemy.delete'):
            Default.query.filter_by(type=aimodel_type).delete()
        with opentracing.tracer.start_active_span('sqlalchemy.create'):
            db.session.add(Default(type=aimodel_type, aimodel=m))
        with opentracing.tracer.start_active_span('sqlalchemy.commit'):
            db.session.commit()
        return '', 204