added bootstrap script
This commit is contained in:
parent
c9a81e1595
commit
4d352c3287
204
bootstrap.py
Normal file
204
bootstrap.py
Normal file
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
from typing import Optional
|
||||
import requests
|
||||
import os.path
|
||||
import sys
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
def compile_multipart_file_part(path: str) -> tuple:
|
||||
return (
|
||||
os.path.basename(path),
|
||||
open(path, 'rb').read(),
|
||||
'application/octet-stream',
|
||||
{'Content-length': os.path.getsize(path)}
|
||||
)
|
||||
|
||||
|
||||
def compile_multipart_json_part(data: dict) -> tuple:
|
||||
return (
|
||||
None,
|
||||
json.dumps(data),
|
||||
"application/json"
|
||||
)
|
||||
|
||||
|
||||
def perform_upload_and_print_result_and_get_uuid(url: str, files: dict) -> Optional[str]:
|
||||
try:
|
||||
r = requests.post(url, files=files, timeout=90)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(" Failed!", flush=True)
|
||||
print(e)
|
||||
return None
|
||||
|
||||
if r.status_code != 200:
|
||||
print(" Failed!", flush=True)
|
||||
print("STATUS:", r.status_code)
|
||||
print("HEADERS: ", r.headers)
|
||||
print("CONTENT: ", r.content)
|
||||
return None
|
||||
else:
|
||||
response_data = r.json()
|
||||
if "id" not in response_data:
|
||||
print(" Failed!", flush=True)
|
||||
print("Invalid response: no id field!")
|
||||
print("STATUS:", r.status_code)
|
||||
print("HEADERS: ", r.headers)
|
||||
print("CONTENT: ", r.content)
|
||||
return None
|
||||
|
||||
print(" Success!", flush=True)
|
||||
return response_data["id"]
|
||||
|
||||
|
||||
def put_json_and_print_result(url: str, data: dict) -> bool:
|
||||
try:
|
||||
r = requests.put(url, json=data)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(" Failed!", flush=True)
|
||||
print(e)
|
||||
return False
|
||||
|
||||
if r.status_code != 204:
|
||||
print(" Failed!", flush=True)
|
||||
print("STATUS:", r.status_code)
|
||||
print("HEADERS: ", r.headers)
|
||||
print("CONTENT: ", r.content)
|
||||
return False
|
||||
else:
|
||||
print(" Success!", flush=True)
|
||||
return True
|
||||
|
||||
|
||||
def get_and_print_failure_only_and_return_response(url: str) -> Optional[dict]:
|
||||
try:
|
||||
r = requests.get(url)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(" Failed!", flush=True)
|
||||
print(e)
|
||||
return None
|
||||
|
||||
if r.status_code != 200:
|
||||
print(" Failed!", flush=True)
|
||||
print("STATUS:", r.status_code)
|
||||
print("HEADERS: ", r.headers)
|
||||
print("CONTENT: ", r.content)
|
||||
return None
|
||||
else:
|
||||
response_data = r.json()
|
||||
return response_data
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: bootstrap.py [API_BASE]")
|
||||
return
|
||||
|
||||
api_base = sys.argv[1]
|
||||
basepath = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...")
|
||||
|
||||
print("[1/5] Uploading CNN model...", end="", flush=True)
|
||||
|
||||
# Upload CNN first
|
||||
cnn_modelFile = os.path.join(basepath, "models/cnn/model_batch_590.json")
|
||||
cnn_weightsFile = os.path.join(basepath, "models/cnn/best_model_batch_590.h5")
|
||||
|
||||
files = {
|
||||
"modelFile": compile_multipart_file_part(cnn_modelFile),
|
||||
"weightsFile": compile_multipart_file_part(cnn_weightsFile),
|
||||
"info": compile_multipart_json_part({"target_class_name": "sturnus"})
|
||||
}
|
||||
|
||||
cnn_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/cnn"), files)
|
||||
|
||||
if not cnn_uuid:
|
||||
return
|
||||
|
||||
print("[2/5] Uploading SVM model...", end="", flush=True)
|
||||
|
||||
# Upload SVM model
|
||||
svm_modelFile = os.path.join(basepath, "models/svm/svm_8_500")
|
||||
svm_meansFile = os.path.join(basepath, "models/svm/svm_8_500MEANS")
|
||||
|
||||
files = {
|
||||
"modelFile": compile_multipart_file_part(svm_modelFile),
|
||||
"meansFile": compile_multipart_file_part(svm_meansFile),
|
||||
"info": compile_multipart_json_part({"target_class_name": "Chirp"})
|
||||
}
|
||||
|
||||
svm_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/svm"), files)
|
||||
|
||||
if not svm_uuid:
|
||||
return
|
||||
|
||||
print("[3/5] Setting default CNN model...", end="", flush=True)
|
||||
if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}):
|
||||
return
|
||||
|
||||
print("[4/5] Setting default SVM model...", end="", flush=True)
|
||||
if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}):
|
||||
return
|
||||
|
||||
print("[5/5] Validating...", end="", flush=True)
|
||||
data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model"))
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
svm_found = False
|
||||
cnn_found = False
|
||||
for model_data in data:
|
||||
if model_data['id'] == cnn_uuid:
|
||||
if not model_data['default']:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded CNN model is not the default")
|
||||
print("DATA:", data)
|
||||
return
|
||||
else:
|
||||
if cnn_found:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded CNN model appears twice")
|
||||
print("DATA:", data)
|
||||
return
|
||||
else:
|
||||
cnn_found = True
|
||||
|
||||
if model_data['id'] == svm_uuid:
|
||||
if not model_data['default']:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded SVM model is not the default")
|
||||
print("DATA:", data)
|
||||
return
|
||||
else:
|
||||
if svm_found:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded SVM model appears twice")
|
||||
print("DATA:", data)
|
||||
return
|
||||
else:
|
||||
svm_found = True
|
||||
|
||||
if not cnn_found:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded CNN model is missing")
|
||||
print("EXPETED:", cnn_uuid)
|
||||
print("DATA:", data)
|
||||
return
|
||||
|
||||
if not svm_found:
|
||||
print(" Failed!", flush=True)
|
||||
print("The uploaded SVM model is missing")
|
||||
print("EXPETED:", svm_uuid)
|
||||
print("DATA:", data)
|
||||
return
|
||||
|
||||
print(" Success!", flush=True)
|
||||
|
||||
print("Your Birbnetes deployment is ready!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user