#!/usr/bin/env python3
"""Bootstrap a Birbnetes deployment with the models bundled next to this script.

Usage: bootstrap.py [API_BASE]

Steps:
  1. upload the CNN model (model + weights file),
  2. upload the SVM model (model + means file),
  3-4. mark both uploads as the default model of their kind,
  5. fetch the model listing and check both uploads are present exactly once
     and flagged as default.

The process exits with status 0 on success and 1 on any failure.
"""
from typing import Any, Optional
import requests
import os.path
import sys
import json
from urllib.parse import urljoin

# Timeout in seconds applied to every HTTP request so an unresponsive server
# cannot hang the bootstrap forever (the upload originally used 90 s).
REQUEST_TIMEOUT = 90


def compile_multipart_file_part(path: str) -> tuple:
    """Build a ``requests`` multipart entry for the file at *path*.

    Returns the 4-tuple ``(filename, content, content_type, extra_headers)``
    accepted as a value in the ``files=`` mapping of ``requests.post``.
    Raises ``OSError`` if the file cannot be read.
    """
    # Context manager closes the handle immediately instead of leaking it
    # until garbage collection.
    with open(path, 'rb') as f:
        content = f.read()
    return (
        os.path.basename(path),
        content,
        'application/octet-stream',
        # Length of the bytes actually sent (no second stat of the file).
        {'Content-length': str(len(content))}
    )


def compile_multipart_json_part(data: dict) -> tuple:
    """Build a filename-less multipart entry carrying *data* serialized as JSON."""
    return (
        None,
        json.dumps(data),
        "application/json"
    )


def _print_exception_failure(e: Exception) -> None:
    """Print the " Failed!" marker followed by the exception text."""
    print(" Failed!", flush=True)
    print(e)


def _print_response_failure(r: requests.Response, reason: Optional[str] = None) -> None:
    """Print the " Failed!" marker, an optional *reason*, and details of response *r*."""
    print(" Failed!", flush=True)
    if reason:
        print(reason)
    print("STATUS:", r.status_code)
    print("HEADERS: ", r.headers)
    print("CONTENT: ", r.content)


def perform_upload_and_print_result_and_get_uuid(url: str, files: dict,
                                                 timeout: float = REQUEST_TIMEOUT) -> Optional[str]:
    """POST *files* as multipart/form-data to *url* and return the created model id.

    Prints " Success!" or a failure report. Returns ``None`` on a transport
    error, a non-200 status, a non-JSON body, or a JSON body without an
    ``"id"`` field.
    """
    try:
        r = requests.post(url, files=files, timeout=timeout)
    except requests.exceptions.RequestException as e:
        _print_exception_failure(e)
        return None

    if r.status_code != 200:
        _print_response_failure(r)
        return None

    try:
        response_data = r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        _print_response_failure(r, "Invalid response: body is not valid JSON!")
        return None

    if not isinstance(response_data, dict) or "id" not in response_data:
        _print_response_failure(r, "Invalid response: no id field!")
        return None

    print(" Success!", flush=True)
    return response_data["id"]


def put_json_and_print_result(url: str, data: dict, timeout: float = REQUEST_TIMEOUT) -> bool:
    """PUT *data* as JSON to *url*; return True iff the server answers 204.

    Prints " Success!" or a failure report.
    """
    try:
        r = requests.put(url, json=data, timeout=timeout)
    except requests.exceptions.RequestException as e:
        _print_exception_failure(e)
        return False

    if r.status_code != 204:
        _print_response_failure(r)
        return False

    print(" Success!", flush=True)
    return True


def get_and_print_failure_only_and_return_response(url: str,
                                                   timeout: float = REQUEST_TIMEOUT) -> Optional[Any]:
    """GET *url* and return the decoded JSON body.

    Prints nothing on success, so the caller can report its own result.
    Prints a failure report and returns ``None`` on a transport error, a
    non-200 status, or a body that is not valid JSON.
    """
    try:
        r = requests.get(url, timeout=timeout)
    except requests.exceptions.RequestException as e:
        _print_exception_failure(e)
        return None

    if r.status_code != 200:
        _print_response_failure(r)
        return None

    try:
        return r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        _print_response_failure(r, "Invalid response: body is not valid JSON!")
        return None


def _upload_model(url: str, file_paths: dict, info: dict) -> Optional[str]:
    """Upload the files in *file_paths* (form field -> path) plus an ``info`` JSON part.

    Returns the new model id, or ``None`` (after printing a failure) when a
    file cannot be read or the upload fails.
    """
    try:
        files = {field: compile_multipart_file_part(path) for field, path in file_paths.items()}
    except OSError as e:
        _print_exception_failure(e)
        return None
    files["info"] = compile_multipart_json_part(info)
    return perform_upload_and_print_result_and_get_uuid(url, files)


def _print_validation_failure(message: str, data: Any, expected: Optional[str] = None) -> None:
    """Print the " Failed!" marker, *message*, the optional expected id, and the listing."""
    print(" Failed!", flush=True)
    print(message)
    if expected is not None:
        print("EXPECTED:", expected)
    print("DATA:", data)


def _validate_model_listing(data: Any, cnn_uuid: str, svm_uuid: str) -> bool:
    """Check that each uploaded model is listed exactly once and flagged as default.

    *data* is the decoded body of ``GET model``; it is expected to be a list
    of objects with ``id`` and ``default`` keys. Prints a failure report and
    returns False on the first problem found, otherwise returns True.
    """
    if not isinstance(data, list):
        _print_validation_failure("Invalid response: model listing is not a list", data)
        return False

    uploaded = (("CNN", cnn_uuid), ("SVM", svm_uuid))
    found = {label: False for label, _ in uploaded}
    for model_data in data:
        if not isinstance(model_data, dict):
            continue
        for label, uuid in uploaded:
            if model_data.get('id') != uuid:
                continue
            if not model_data.get('default'):
                _print_validation_failure(f"The uploaded {label} model is not the default", data)
                return False
            if found[label]:
                _print_validation_failure(f"The uploaded {label} model appears twice", data)
                return False
            found[label] = True

    for label, uuid in uploaded:
        if not found[label]:
            _print_validation_failure(f"The uploaded {label} model is missing", data, expected=uuid)
            return False
    return True


def main() -> int:
    """Run the five bootstrap steps; return the process exit status (0 = success)."""
    if len(sys.argv) != 2:
        print("Usage: bootstrap.py [API_BASE]", file=sys.stderr)
        return 1

    api_base = sys.argv[1]
    # urljoin() replaces the last path segment of a base that lacks a trailing
    # slash (urljoin("http://h/api", "model") == "http://h/model"), so force the
    # base to be treated as a directory.
    if not api_base.endswith("/"):
        api_base += "/"
    basepath = os.path.dirname(os.path.abspath(__file__))

    print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...")

    print("[1/5] Uploading CNN model...", end="", flush=True)
    cnn_uuid = _upload_model(
        urljoin(api_base, "model/cnn"),
        {
            "modelFile": os.path.join(basepath, "models/cnn/model_batch_590.json"),
            "weightsFile": os.path.join(basepath, "models/cnn/best_model_batch_590.h5"),
        },
        {"target_class_name": "sturnus"},
    )
    if not cnn_uuid:
        return 1

    print("[2/5] Uploading SVM model...", end="", flush=True)
    svm_uuid = _upload_model(
        urljoin(api_base, "model/svm"),
        {
            "modelFile": os.path.join(basepath, "models/svm/svm_8_500"),
            "meansFile": os.path.join(basepath, "models/svm/svm_8_500MEANS"),
        },
        {"target_class_name": "Chirp"},
    )
    if not svm_uuid:
        return 1

    print("[3/5] Setting default CNN model...", end="", flush=True)
    if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}):
        return 1

    print("[4/5] Setting default SVM model...", end="", flush=True)
    if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}):
        return 1

    print("[5/5] Validating...", end="", flush=True)
    data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model"))
    # Compare against None: an empty listing is a valid response and must be
    # reported as "missing" by the validator, not silently ignored.
    if data is None:
        return 1
    if not _validate_model_listing(data, cnn_uuid, svm_uuid):
        return 1

    print(" Success!", flush=True)
    print("Your Birbnetes deployment is ready!")
    return 0


if __name__ == '__main__':
    sys.exit(main())