diff --git a/bootstrap.py b/bootstrap.py new file mode 100644 index 0000000..f727fd3 --- /dev/null +++ b/bootstrap.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +from typing import Optional +import requests +import os.path +import sys +import json +from urllib.parse import urljoin + + +def compile_multipart_file_part(path: str) -> tuple: + return ( + os.path.basename(path), + open(path, 'rb').read(), + 'application/octet-stream', + {'Content-length': os.path.getsize(path)} + ) + + +def compile_multipart_json_part(data: dict) -> tuple: + return ( + None, + json.dumps(data), + "application/json" + ) + + +def perform_upload_and_print_result_and_get_uuid(url: str, files: dict) -> Optional[str]: + try: + r = requests.post(url, files=files, timeout=90) + except requests.exceptions.RequestException as e: + print(" Failed!", flush=True) + print(e) + return None + + if r.status_code != 200: + print(" Failed!", flush=True) + print("STATUS:", r.status_code) + print("HEADERS: ", r.headers) + print("CONTENT: ", r.content) + return None + else: + response_data = r.json() + if "id" not in response_data: + print(" Failed!", flush=True) + print("Invalid response: no id field!") + print("STATUS:", r.status_code) + print("HEADERS: ", r.headers) + print("CONTENT: ", r.content) + return None + + print(" Success!", flush=True) + return response_data["id"] + + +def put_json_and_print_result(url: str, data: dict) -> bool: + try: + r = requests.put(url, json=data) + except requests.exceptions.RequestException as e: + print(" Failed!", flush=True) + print(e) + return False + + if r.status_code != 204: + print(" Failed!", flush=True) + print("STATUS:", r.status_code) + print("HEADERS: ", r.headers) + print("CONTENT: ", r.content) + return False + else: + print(" Success!", flush=True) + return True + + +def get_and_print_failure_only_and_return_response(url: str) -> Optional[dict]: + try: + r = requests.get(url) + except requests.exceptions.RequestException as e: + print(" Failed!", flush=True) + print(e) + return None + + if r.status_code != 200: + print(" Failed!", flush=True) + print("STATUS:", r.status_code) + print("HEADERS: ", r.headers) + print("CONTENT: ", r.content) + return None + else: + response_data = r.json() + return response_data + + +def main(): + if len(sys.argv) != 2: + print("Usage: bootstrap.py [API_BASE]") + return + + api_base = sys.argv[1] + basepath = os.path.dirname(os.path.abspath(__file__)) + + print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...") + + print("[1/5] Uploading CNN model...", end="", flush=True) + + # Upload CNN first + cnn_modelFile = os.path.join(basepath, "models/cnn/model_batch_590.json") + cnn_weightsFile = os.path.join(basepath, "models/cnn/best_model_batch_590.h5") + + files = { + "modelFile": compile_multipart_file_part(cnn_modelFile), + "weightsFile": compile_multipart_file_part(cnn_weightsFile), + "info": compile_multipart_json_part({"target_class_name": "sturnus"}) + } + + cnn_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/cnn"), files) + + if not cnn_uuid: + return + + print("[2/5] Uploading SVM model...", end="", flush=True) + + # Upload SVM model + svm_modelFile = os.path.join(basepath, "models/svm/svm_8_500") + svm_meansFile = os.path.join(basepath, "models/svm/svm_8_500MEANS") + + files = { + "modelFile": compile_multipart_file_part(svm_modelFile), + "meansFile": compile_multipart_file_part(svm_meansFile), + "info": compile_multipart_json_part({"target_class_name": "Chirp"}) + } + + svm_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/svm"), files) + + if not svm_uuid: + return + + print("[3/5] Setting default CNN model...", end="", flush=True) + if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}): + return + + print("[4/5] Setting default SVM model...", end="", flush=True) + if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}): + return + + print("[5/5] Validating...", end="", flush=True) + data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model")) + + if not data: + return + + svm_found = False + cnn_found = False + for model_data in data: + if model_data['id'] == cnn_uuid: + if not model_data['default']: + print(" Failed!", flush=True) + print("The uploaded CNN model is not the default") + print("DATA:", data) + return + else: + if cnn_found: + print(" Failed!", flush=True) + print("The uploaded CNN model appears twice") + print("DATA:", data) + return + else: + cnn_found = True + + if model_data['id'] == svm_uuid: + if not model_data['default']: + print(" Failed!", flush=True) + print("The uploaded SVM model is not the default") + print("DATA:", data) + return + else: + if svm_found: + print(" Failed!", flush=True) + print("The uploaded SVM model appears twice") + print("DATA:", data) + return + else: + svm_found = True + + if not cnn_found: + print(" Failed!", flush=True) + print("The uploaded CNN model is missing") + print("EXPETED:", cnn_uuid) + print("DATA:", data) + return + + if not svm_found: + print(" Failed!", flush=True) + print("The uploaded SVM model is missing") + print("EXPETED:", svm_uuid) + print("DATA:", data) + return + + print(" Success!", flush=True) + + print("Your Birbnetes deployment is ready!") + + +if __name__ == '__main__': + main()