added bootstrap script
This commit is contained in:
		
							
								
								
									
										204
									
								
								bootstrap.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								bootstrap.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,204 @@ | |||||||
|  | #!/usr/bin/env python3 | ||||||
|  | from typing import Optional | ||||||
|  | import requests | ||||||
|  | import os.path | ||||||
|  | import sys | ||||||
|  | import json | ||||||
|  | from urllib.parse import urljoin | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def compile_multipart_file_part(path: str) -> tuple: | ||||||
|  |     return ( | ||||||
|  |         os.path.basename(path), | ||||||
|  |         open(path, 'rb').read(), | ||||||
|  |         'application/octet-stream', | ||||||
|  |         {'Content-length': os.path.getsize(path)} | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def compile_multipart_json_part(data: dict) -> tuple: | ||||||
|  |     return ( | ||||||
|  |         None, | ||||||
|  |         json.dumps(data), | ||||||
|  |         "application/json" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def perform_upload_and_print_result_and_get_uuid(url: str, files: dict) -> Optional[str]: | ||||||
|  |     try: | ||||||
|  |         r = requests.post(url, files=files, timeout=90) | ||||||
|  |     except requests.exceptions.RequestException as e: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print(e) | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     if r.status_code != 200: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print("STATUS:", r.status_code) | ||||||
|  |         print("HEADERS: ", r.headers) | ||||||
|  |         print("CONTENT: ", r.content) | ||||||
|  |         return None | ||||||
|  |     else: | ||||||
|  |         response_data = r.json() | ||||||
|  |         if "id" not in response_data: | ||||||
|  |             print(" Failed!", flush=True) | ||||||
|  |             print("Invalid response: no id field!") | ||||||
|  |             print("STATUS:", r.status_code) | ||||||
|  |             print("HEADERS: ", r.headers) | ||||||
|  |             print("CONTENT: ", r.content) | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         print(" Success!", flush=True) | ||||||
|  |         return response_data["id"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def put_json_and_print_result(url: str, data: dict) -> bool: | ||||||
|  |     try: | ||||||
|  |         r = requests.put(url, json=data) | ||||||
|  |     except requests.exceptions.RequestException as e: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print(e) | ||||||
|  |         return False | ||||||
|  |  | ||||||
|  |     if r.status_code != 204: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print("STATUS:", r.status_code) | ||||||
|  |         print("HEADERS: ", r.headers) | ||||||
|  |         print("CONTENT: ", r.content) | ||||||
|  |         return False | ||||||
|  |     else: | ||||||
|  |         print(" Success!", flush=True) | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_and_print_failure_only_and_return_response(url: str) -> Optional[dict]: | ||||||
|  |     try: | ||||||
|  |         r = requests.get(url) | ||||||
|  |     except requests.exceptions.RequestException as e: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print(e) | ||||||
|  |         return None | ||||||
|  |  | ||||||
|  |     if r.status_code != 200: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print("STATUS:", r.status_code) | ||||||
|  |         print("HEADERS: ", r.headers) | ||||||
|  |         print("CONTENT: ", r.content) | ||||||
|  |         return None | ||||||
|  |     else: | ||||||
|  |         response_data = r.json() | ||||||
|  |         return response_data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     if len(sys.argv) != 2: | ||||||
|  |         print("Usage: bootstrap.py [API_BASE]") | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     api_base = sys.argv[1] | ||||||
|  |     basepath = os.path.dirname(os.path.abspath(__file__)) | ||||||
|  |  | ||||||
|  |     print(f"Bootstrapping Birbnetes deployment at {api_base} with models in {basepath}...") | ||||||
|  |  | ||||||
|  |     print("[1/5] Uploading CNN model...", end="", flush=True) | ||||||
|  |  | ||||||
|  |     # Upload CNN first | ||||||
|  |     cnn_modelFile = os.path.join(basepath, "models/cnn/model_batch_590.json") | ||||||
|  |     cnn_weightsFile = os.path.join(basepath, "models/cnn/best_model_batch_590.h5") | ||||||
|  |  | ||||||
|  |     files = { | ||||||
|  |         "modelFile": compile_multipart_file_part(cnn_modelFile), | ||||||
|  |         "weightsFile": compile_multipart_file_part(cnn_weightsFile), | ||||||
|  |         "info": compile_multipart_json_part({"target_class_name": "sturnus"}) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     cnn_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/cnn"), files) | ||||||
|  |  | ||||||
|  |     if not cnn_uuid: | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print("[2/5] Uploading SVM model...", end="", flush=True) | ||||||
|  |  | ||||||
|  |     # Upload SVM model | ||||||
|  |     svm_modelFile = os.path.join(basepath, "models/svm/svm_8_500") | ||||||
|  |     svm_meansFile = os.path.join(basepath, "models/svm/svm_8_500MEANS") | ||||||
|  |  | ||||||
|  |     files = { | ||||||
|  |         "modelFile": compile_multipart_file_part(svm_modelFile), | ||||||
|  |         "meansFile": compile_multipart_file_part(svm_meansFile), | ||||||
|  |         "info": compile_multipart_json_part({"target_class_name": "Chirp"}) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     svm_uuid = perform_upload_and_print_result_and_get_uuid(urljoin(api_base, "model/svm"), files) | ||||||
|  |  | ||||||
|  |     if not svm_uuid: | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print("[3/5] Setting default CNN model...", end="", flush=True) | ||||||
|  |     if not put_json_and_print_result(urljoin(api_base, "model/cnn/$default"), {"id": cnn_uuid}): | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print("[4/5] Setting default SVM model...", end="", flush=True) | ||||||
|  |     if not put_json_and_print_result(urljoin(api_base, "model/svm/$default"), {"id": svm_uuid}): | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print("[5/5] Validating...", end="", flush=True) | ||||||
|  |     data = get_and_print_failure_only_and_return_response(urljoin(api_base, "model")) | ||||||
|  |  | ||||||
|  |     if not data: | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     svm_found = False | ||||||
|  |     cnn_found = False | ||||||
|  |     for model_data in data: | ||||||
|  |         if model_data['id'] == cnn_uuid: | ||||||
|  |             if not model_data['default']: | ||||||
|  |                 print(" Failed!", flush=True) | ||||||
|  |                 print("The uploaded CNN model is not the default") | ||||||
|  |                 print("DATA:", data) | ||||||
|  |                 return | ||||||
|  |             else: | ||||||
|  |                 if cnn_found: | ||||||
|  |                     print(" Failed!", flush=True) | ||||||
|  |                     print("The uploaded CNN model appears twice") | ||||||
|  |                     print("DATA:", data) | ||||||
|  |                     return | ||||||
|  |                 else: | ||||||
|  |                     cnn_found = True | ||||||
|  |  | ||||||
|  |         if model_data['id'] == svm_uuid: | ||||||
|  |             if not model_data['default']: | ||||||
|  |                 print(" Failed!", flush=True) | ||||||
|  |                 print("The uploaded SVM model is not the default") | ||||||
|  |                 print("DATA:", data) | ||||||
|  |                 return | ||||||
|  |             else: | ||||||
|  |                 if svm_found: | ||||||
|  |                     print(" Failed!", flush=True) | ||||||
|  |                     print("The uploaded SVM model appears twice") | ||||||
|  |                     print("DATA:", data) | ||||||
|  |                     return | ||||||
|  |                 else: | ||||||
|  |                     svm_found = True | ||||||
|  |  | ||||||
|  |     if not cnn_found: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print("The uploaded CNN model is missing") | ||||||
|  |         print("EXPETED:", cnn_uuid) | ||||||
|  |         print("DATA:", data) | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     if not svm_found: | ||||||
|  |         print(" Failed!", flush=True) | ||||||
|  |         print("The uploaded SVM model is missing") | ||||||
|  |         print("EXPETED:", svm_uuid) | ||||||
|  |         print("DATA:", data) | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print(" Success!", flush=True) | ||||||
|  |  | ||||||
|  |     print("Your Birbnetes deployment is ready!") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
		Reference in New Issue
	
	Block a user