added bootstrap script

This commit is contained in:
Pünkösd Marcell 2021-07-27 18:51:22 +02:00
parent c9a81e1595
commit 4d352c3287

204
bootstrap.py Normal file
View File

@ -0,0 +1,204 @@
#!/usr/bin/env python3
from typing import Optional
import requests
import os.path
import sys
import json
from urllib.parse import urljoin
def compile_multipart_file_part(path: str) -> tuple:
return (
os.path.basename(path),
open(path, 'rb').read(),
'application/octet-stream',
{'Content-length': os.path.getsize(path)}
)
def compile_multipart_json_part(data: dict) -> tuple:
return (
None,
json.dumps(data),
"application/json"
)
def perform_upload_and_print_result_and_get_uuid(url: str, files: dict) -> Optional[str]:
try:
r = requests.post(url, files=files, timeout=90)
except requests.exceptions.RequestException as e:
print(" Failed!", flush=True)
print(e)
return None
if r.status_code != 200:
print(" Failed!", flush=True)
print("STATUS:", r.status_code)
print("HEADERS: ", r.headers)
print("CONTENT: ", r.content)
return None
else:
response_data = r.json()
if "id" not in response_data:
print(" Failed!", flush=True)
print("Invalid response: no id field!")
print("STATUS:", r.status_code)
print("HEADERS: ", r.headers)
print("CONTENT: ", r.content)
return None
print(" Success!", flush=True)
return response_data["id"]
def put_json_and_print_result(url: str, data: dict) -> bool:
try:
r = requests.put(url, json=data)
except requests.exceptions.RequestException as e:
print(" Failed!", flush=True)
print(e)
return False
if r.status_code != 204:
print(" Failed!", flush=True)
print("STATUS:", r.status_code)
print("HEADERS: ", r.headers)
print("CONTENT: ", r.content)
return False
else:
print(" Success!", flush=True)
return True
def get_and_print_failure_only_and_return_response(url: str) -> Optional[dict]:
try:
r = requests.get(url)
except requests.exceptions.RequestException as e:
print(" Failed!", flush=True)
print(e)
return None
if r.status_code != 200:
print(" Failed!", flush=True)
print("STATUS:", r.status_code)
print("HEADERS: ", r.headers)
print("CONTENT: ", r.content)
return None
else:
response_data = r.json()
return response_data
def main():
if len(sys.argv) != 2:
print("Usage: bootstrap.py [API_BASE]")
return
api_base = sys.argv[1]
basepath = os.path.dirname(os.path.abspath(__file__))
print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...")
print("[1/5] Uploading CNN model...", end="", flush=True)
# Upload CNN first
cnn_modelFile = os.path.join(basepath, "models/cnn/model_batch_590.json")
cnn_weightsFile = os.path.join(basepath, "models/cnn/best_model_batch_590.h5")
files = {
"modelFile": compile_multipart_file_part(cnn_modelFile),
"weightsFile": compile_multipart_file_part(cnn_weightsFile),
"info": compile_multipart_json_part({"target_class_name": "sturnus"})
}
cnn_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/cnn"), files)
if not cnn_uuid:
return
print("[2/5] Uploading SVM model...", end="", flush=True)
# Upload SVM model
svm_modelFile = os.path.join(basepath, "models/svm/svm_8_500")
svm_meansFile = os.path.join(basepath, "models/svm/svm_8_500MEANS")
files = {
"modelFile": compile_multipart_file_part(svm_modelFile),
"meansFile": compile_multipart_file_part(svm_meansFile),
"info": compile_multipart_json_part({"target_class_name": "Chirp"})
}
svm_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/svm"), files)
if not svm_uuid:
return
print("[3/5] Setting default CNN model...", end="", flush=True)
if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}):
return
print("[4/5] Setting default SVM model...", end="", flush=True)
if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}):
return
print("[5/5] Validating...", end="", flush=True)
data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model"))
if not data:
return
svm_found = False
cnn_found = False
for model_data in data:
if model_data['id'] == cnn_uuid:
if not model_data['default']:
print(" Failed!", flush=True)
print("The uploaded CNN model is not the default")
print("DATA:", data)
return
else:
if cnn_found:
print(" Failed!", flush=True)
print("The uploaded CNN model appears twice")
print("DATA:", data)
return
else:
cnn_found = True
if model_data['id'] == svm_uuid:
if not model_data['default']:
print(" Failed!", flush=True)
print("The uploaded SVM model is not the default")
print("DATA:", data)
return
else:
if svm_found:
print(" Failed!", flush=True)
print("The uploaded SVM model appears twice")
print("DATA:", data)
return
else:
svm_found = True
if not cnn_found:
print(" Failed!", flush=True)
print("The uploaded CNN model is missing")
print("EXPETED:", cnn_uuid)
print("DATA:", data)
return
if not svm_found:
print(" Failed!", flush=True)
print("The uploaded SVM model is missing")
print("EXPETED:", svm_uuid)
print("DATA:", data)
return
print(" Success!", flush=True)
print("Your Birbnetes deployment is ready!")
if __name__ == '__main__':
main()