added bootstrap script
This commit is contained in:
parent
c9a81e1595
commit
4d352c3287
204
bootstrap.py
Normal file
204
bootstrap.py
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
from typing import Optional
|
||||||
|
import requests
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
|
||||||
|
def compile_multipart_file_part(path: str) -> tuple:
|
||||||
|
return (
|
||||||
|
os.path.basename(path),
|
||||||
|
open(path, 'rb').read(),
|
||||||
|
'application/octet-stream',
|
||||||
|
{'Content-length': os.path.getsize(path)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_multipart_json_part(data: dict) -> tuple:
|
||||||
|
return (
|
||||||
|
None,
|
||||||
|
json.dumps(data),
|
||||||
|
"application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def perform_upload_and_print_result_and_get_uuid(url: str, files: dict) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
r = requests.post(url, files=files, timeout=90)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if r.status_code != 200:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("STATUS:", r.status_code)
|
||||||
|
print("HEADERS: ", r.headers)
|
||||||
|
print("CONTENT: ", r.content)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
response_data = r.json()
|
||||||
|
if "id" not in response_data:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("Invalid response: no id field!")
|
||||||
|
print("STATUS:", r.status_code)
|
||||||
|
print("HEADERS: ", r.headers)
|
||||||
|
print("CONTENT: ", r.content)
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(" Success!", flush=True)
|
||||||
|
return response_data["id"]
|
||||||
|
|
||||||
|
|
||||||
|
def put_json_and_print_result(url: str, data: dict) -> bool:
|
||||||
|
try:
|
||||||
|
r = requests.put(url, json=data)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print(e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if r.status_code != 204:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("STATUS:", r.status_code)
|
||||||
|
print("HEADERS: ", r.headers)
|
||||||
|
print("CONTENT: ", r.content)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print(" Success!", flush=True)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_and_print_failure_only_and_return_response(url: str) -> Optional[dict]:
|
||||||
|
try:
|
||||||
|
r = requests.get(url)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if r.status_code != 200:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("STATUS:", r.status_code)
|
||||||
|
print("HEADERS: ", r.headers)
|
||||||
|
print("CONTENT: ", r.content)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
response_data = r.json()
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) != 2:
|
||||||
|
print("Usage: bootstrap.py [API_BASE]")
|
||||||
|
return
|
||||||
|
|
||||||
|
api_base = sys.argv[1]
|
||||||
|
basepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...")
|
||||||
|
|
||||||
|
print("[1/5] Uploading CNN model...", end="", flush=True)
|
||||||
|
|
||||||
|
# Upload CNN first
|
||||||
|
cnn_modelFile = os.path.join(basepath, "models/cnn/model_batch_590.json")
|
||||||
|
cnn_weightsFile = os.path.join(basepath, "models/cnn/best_model_batch_590.h5")
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"modelFile": compile_multipart_file_part(cnn_modelFile),
|
||||||
|
"weightsFile": compile_multipart_file_part(cnn_weightsFile),
|
||||||
|
"info": compile_multipart_json_part({"target_class_name": "sturnus"})
|
||||||
|
}
|
||||||
|
|
||||||
|
cnn_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/cnn"), files)
|
||||||
|
|
||||||
|
if not cnn_uuid:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[2/5] Uploading SVM model...", end="", flush=True)
|
||||||
|
|
||||||
|
# Upload SVM model
|
||||||
|
svm_modelFile = os.path.join(basepath, "models/svm/svm_8_500")
|
||||||
|
svm_meansFile = os.path.join(basepath, "models/svm/svm_8_500MEANS")
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"modelFile": compile_multipart_file_part(svm_modelFile),
|
||||||
|
"meansFile": compile_multipart_file_part(svm_meansFile),
|
||||||
|
"info": compile_multipart_json_part({"target_class_name": "Chirp"})
|
||||||
|
}
|
||||||
|
|
||||||
|
svm_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/svm"), files)
|
||||||
|
|
||||||
|
if not svm_uuid:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[3/5] Setting default CNN model...", end="", flush=True)
|
||||||
|
if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}):
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[4/5] Setting default SVM model...", end="", flush=True)
|
||||||
|
if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}):
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[5/5] Validating...", end="", flush=True)
|
||||||
|
data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model"))
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
svm_found = False
|
||||||
|
cnn_found = False
|
||||||
|
for model_data in data:
|
||||||
|
if model_data['id'] == cnn_uuid:
|
||||||
|
if not model_data['default']:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded CNN model is not the default")
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if cnn_found:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded CNN model appears twice")
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
cnn_found = True
|
||||||
|
|
||||||
|
if model_data['id'] == svm_uuid:
|
||||||
|
if not model_data['default']:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded SVM model is not the default")
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if svm_found:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded SVM model appears twice")
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
svm_found = True
|
||||||
|
|
||||||
|
if not cnn_found:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded CNN model is missing")
|
||||||
|
print("EXPETED:", cnn_uuid)
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not svm_found:
|
||||||
|
print(" Failed!", flush=True)
|
||||||
|
print("The uploaded SVM model is missing")
|
||||||
|
print("EXPETED:", svm_uuid)
|
||||||
|
print("DATA:", data)
|
||||||
|
return
|
||||||
|
|
||||||
|
print(" Success!", flush=True)
|
||||||
|
|
||||||
|
print("Your Birbnetes deployment is ready!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user