diff --git a/requirements.txt b/requirements.txt index c21b6ec..5da17aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -pycryptodome \ No newline at end of file +pycryptodome +pydh \ No newline at end of file diff --git a/server/netwrapper.py b/server/netwrapper.py index f97f45b..a7919f9 100644 --- a/server/netwrapper.py +++ b/server/netwrapper.py @@ -1,22 +1,133 @@ #!/usr/bin/env python3 +import json +from base64 import b64encode, b64decode +import pyDH +from Crypto.Cipher import PKCS1_OAEP +from Crypto.PublicKey import RSA +from Crypto.Cipher import ChaCha20 +from Crypto.Random import get_random_bytes + +from netsim import network_interface +from server.server import Server + class NetWrapper: - def __init__(self, publicKey: str, privateKey: str, cipherKey: str = ""): - self.publicKey = publicKey - self.privateKey = privateKey - self.cipherKey = cipherKey + def __init__(self, clientPublicKey: dict, serverPrivateKey: str, serverInstance: Server): + self.clientPublicKey = clientPublicKey + self.currentClientPublicKey = "".encode('UTF-8') + self.serverPrivateKey = serverPrivateKey + self.cipherkey = "".encode('UTF-8') + self.network = network_interface('./', 'A') + self.clientAddr = "" + self.currentUser = "" + self.serverInstance = serverInstance - def identifyServer(self,message: bytes): - #Message is coded with Server RSA public key in string byte format - #Create network_interface - pass + def serverIdentify(self, msg: bytes) -> None: + incommingJson = json.loads(msg.decode('UTF-8')) + self.clientAddr = incommingJson['source'] + self.currentUser = incommingJson['username'] + self.currentClientPublicKey = self.clientPublicKey[self.currentUser] + myrsakey = RSA.import_key(self.serverPrivateKey) + mycipher = PKCS1_OAEP.new(myrsakey) + retmsg = mycipher.decrypt(b64decode(incommingJson['message'])).decode('UTF-8') + rsakey = RSA.import_key(self.currentClientPublicKey) + cipher = PKCS1_OAEP.new(rsakey) + identMsg = json.dumps( + {'type': 'IDY', 'source': self.network.own_addr, + 'message': b64encode(cipher.encrypt(retmsg.encode('UTF-8')))}).encode( + 'UTF-8') + self.network.send_msg(self.clientAddr, identMsg) - def identifyCleint(self): - pass + def sendMessage(self, message: bytes) -> None: + cipher = ChaCha20.new(self.cipherkey, get_random_bytes(12)) + ciphertext = cipher.encrypt(message) + nonce = b64encode(cipher.nonce).decode('UTF-8') + ct = b64encode(ciphertext).decode('UTF-8') + sendjson = json.dumps({'type': 'CMD', 'source': self.network.own_addr, 'nonce': nonce, 'message': ct}).encode( + 'UTF-8') + self.network.send_msg(self.clientAddr, sendjson) - def sendMessage(self): - pass + def keyExchange(self) -> None: + dh = pyDH.DiffieHellman() + rsakey = RSA.import_key(self.currentClientPublicKey) + cipher = PKCS1_OAEP.new(rsakey) + mypubkey = b64encode(cipher.encrypt(str(dh.gen_public_key()).encode('UTF-8'))) + jsonmsg = json.dumps({'type': 'DH', 'source': self.network.own_addr, 'message': mypubkey}).encode('UTF-8') + self.network.send_msg(self.clientAddr, jsonmsg) + decodedmsg = {'source': '', 'type': ''} + while not (decodedmsg['source'] == self.clientAddr and decodedmsg['type'] == 'DH'): + status, msg = self.network.receive_msg(blocking=True) + if not status: + raise Exception('Network error during connection.') + decodedmsg = json.loads(msg.decode('UTF-8')) + myrsakey = RSA.import_key(self.serverPrivateKey) + mycipher = PKCS1_OAEP.new(myrsakey) + serverpubkey = int(mycipher.decrypt(b64decode(decodedmsg['message'])).decode('UTF-8')) + self.cipherkey = dh.gen_shared_key(serverpubkey).encode('UTF-8') - def recieveMessage(self): - pass + def login(self) -> bool: + b64 = {'source': '', 'type': ''} + while not (b64['source'] == self.clientAddr and b64['type'] == 'AUT'): + status, msg = self.network.receive_msg(blocking=True) + if not status: + raise Exception('Network error during connection.') + b64 = json.loads(msg.decode('UTF-8')) + try: + retnonce = b64decode(b64['nonce']) + retciphertext = b64decode(b64['message']) + retcipher = ChaCha20.new(self.cipherkey, nonce=retnonce) + plaintext = retcipher.decrypt(retciphertext).decode('UTF-8').split(' ') + linsuccess = (not (len(plaintext) != 3 or plaintext[0] != "LIN" or plaintext[1] != self.currentUser)) and self.serverInstance.login(plaintext[1],plaintext[2]) + if linsuccess: + message = "OK".encode('UTF-8') + else: + message = "ERROR".encode('UTF-8') + cipher = ChaCha20.new(self.cipherkey, get_random_bytes(12)) + ciphertext = cipher.encrypt(message) + nonce = b64encode(cipher.nonce).decode('UTF-8') + ct = b64encode(ciphertext).decode('UTF-8') + sendjson = json.dumps({'type': 'AUT', 'source': self.network.own_addr, 'nonce': nonce, 'message': ct}).encode( + 'UTF-8') + self.network.send_msg(self.clientAddr, sendjson) + return linsuccess + except Exception: + print("Incorrect decryption") + + def initClientConnection(self, msg: bytes) -> bytes: + try: + self.serverIdentify(msg) + self.keyExchange() + success = self.login() + if success: + return b"LINOK" + else: + self.logout() + except Exception: + self.logout() + + def recieveMessage(self) -> bytes: + status, msg = self.network.receive_msg(blocking=True) + if not status: + raise Exception('Network error during connection.') + if not self.clientAddr: + return self.initClientConnection(msg) + else: + return self.recieveEncryptedMessage(msg) + + def logout(self) -> None: + self.clientAddr = "" + self.cipherkey = "".encode('UTF-8') + self.currentClientPublicKey = "".encode('UTF-8') + self.currentUser = "" + + def recieveEncryptedMessage(self, msg: bytes) -> bytes: + try: + b64 = json.loads(msg) + retnonce = b64decode(b64['nonce']) + retciphertext = b64decode(b64['message']) + retcipher = ChaCha20.new(self.cipherkey, nonce=retnonce) + plaintext = retcipher.decrypt(retciphertext) + return plaintext + except Exception: + print("Incorrect decryption")