import os
import subprocess

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

# AES block size in bytes; encrypt() prepends a fresh IV of this length to
# every ciphertext and decrypt() strips it off again.
_IV_SIZE = 16


def generate_parameters(generator: int, key_size: int, backend=default_backend()):
    """
    Generate the parameters for the Diffie-Hellman key exchange.

    :param generator: DH group generator; only 2 and 5 are accepted.
    :param key_size: size of the prime modulus in bits; must be at least 512.
        NOTE: 512 is only this module's lower bound -- 2048 bits or more is
        the usual recommendation for anything security-relevant.
    :param backend: cryptography backend forwarded to ``dh.generate_parameters``.
    :return: the generated ``dh.DHParameters``.
    :raises ValueError: if ``key_size`` is below 512 or ``generator`` is not 2 or 5.
    """
    if key_size < 512:
        raise ValueError("Key size must be at least 512 bits")
    if generator not in (2, 5):
        raise ValueError("Generator must be 2 or 5")
    return dh.generate_parameters(generator, key_size, backend)


def generate_key_pair(parameters: dh.DHParameters):
    """
    Generate a key pair for the Diffie-Hellman key exchange.

    :param parameters: the DH parameters shared by both peers.
    :return: tuple ``(private_key, public_key)``.
    """
    private_key = parameters.generate_private_key()
    return private_key, private_key.public_key()


def derive_keys(private_key: dh.DHPrivateKey, peer_public_key: dh.DHPublicKey) -> bytes:
    """
    Derive a symmetric key from our private key and the peer's public key.

    The raw DH shared secret is run through HKDF-SHA256 (as the cryptography
    documentation recommends) to produce a 32-byte key usable with AES-256.
    Both peers derive the same key because they use the same HKDF settings.

    :param private_key: our DH private key.
    :param peer_public_key: the peer's DH public key (same parameters).
    :return: the 32-byte derived key.
    """
    shared_key = private_key.exchange(peer_public_key)
    derived_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b'handshake data',
        backend=default_backend(),
    ).derive(shared_key)
    print("Successfully derived key")
    return derived_key


def encrypt(content: bytes, key: bytes) -> bytes:
    """
    Encrypt content using the derived key (AES in CFB mode).

    A random IV is generated for every call and prepended to the output, so
    the result has the layout ``iv (16 bytes) || ciphertext``.

    NOTE: CFB provides confidentiality only -- there is no authentication tag,
    so modified ciphertext decrypts to garbage without any error. An AEAD mode
    such as AES-GCM would detect tampering but changes the output format.

    :param content: plaintext bytes.
    :param key: AES key, e.g. the 32-byte output of ``derive_keys``.
    :return: the IV followed by the ciphertext.
    """
    iv = os.urandom(_IV_SIZE)
    encryptor = Cipher(algorithms.AES(key), modes.CFB(iv), backend=default_backend()).encryptor()
    return iv + encryptor.update(content) + encryptor.finalize()


def decrypt(ciphertext: bytes, key: bytes) -> bytes:
    """
    Decrypt data produced by ``encrypt`` using the derived key.

    :param ciphertext: ``iv (16 bytes) || ciphertext`` as returned by ``encrypt``.
    :param key: the same AES key that was used for encryption.
    :return: the recovered plaintext bytes.
    :raises ValueError: if ``ciphertext`` is too short to contain the IV.
    """
    if len(ciphertext) < _IV_SIZE:
        raise ValueError("Ciphertext is too short to contain the IV")
    iv, body = ciphertext[:_IV_SIZE], ciphertext[_IV_SIZE:]
    decryptor = Cipher(algorithms.AES(key), modes.CFB(iv), backend=default_backend()).decryptor()
    return decryptor.update(body) + decryptor.finalize()


def load_dh_public_key(file):
    """
    Load a PEM-encoded public key from ``file``.

    NOTE(review): the key type is not checked, so any PEM public key (not only
    DH) is returned as-is.

    :param file: path to the PEM file.
    :return: the loaded public key object.
    """
    with open(file, 'rb') as key_file:
        public_key = serialization.load_pem_public_key(key_file.read())
    return public_key


def load_dh_private_key(file, passwd=None):
    """
    Load a PEM-encoded private key from ``file``.

    :param file: path to the PEM file.
    :param passwd: optional password as ``str``; encoded as UTF-8 before use.
    :return: the loaded private key object.
    :raises ValueError: if the key cannot be loaded (wrong password or
        malformed key data).
    :raises TypeError: propagated from cryptography when a password is given
        for an unencrypted key, or missing for an encrypted one.
    """
    password = passwd.encode('utf-8') if passwd is not None else None
    with open(file, 'rb') as key_file:
        pem_data = key_file.read()
    # Only the parse step is guarded: cryptography reports both a wrong
    # password and malformed PEM data as ValueError.
    try:
        private_key = serialization.load_pem_private_key(pem_data, password=password)
    except ValueError as e:
        raise ValueError("Error: The password is not valid or the key data is malformed.") from e
    return private_key


def _private_key_pem(private_key) -> bytes:
    """Serialize ``private_key`` as unencrypted PKCS#8 PEM bytes."""
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def _write_bytes(path, data: bytes) -> None:
    """Write ``data`` to ``path`` in binary mode, replacing any existing file."""
    with open(path, 'wb') as f:
        f.write(data)


def diffie_hellman():
    """
    Demo: run a DH exchange between a simulated client and server, store the
    resulting keys, and round-trip a test file through encrypt/decrypt.

    Side effects: writes ``client_private_key.pem``, ``server_private_key.pem``,
    ``client_derived_key.pem``, ``server_derived_key.pem`` and ``file.txt`` in
    the current directory, runs ``dd``, and prints key material to stdout.
    """
    # imagine that the agreement was done beforehand
    generator = 2
    key_size = 1024
    parameters = generate_parameters(generator, key_size)

    # generate keys for client and server
    client_private_key, client_public_key = generate_key_pair(parameters)
    server_private_key, server_public_key = generate_key_pair(parameters)

    # derive keys: each side combines its own private key with the other's
    # public key, so both should arrive at the same derived key
    client_derived_key = derive_keys(client_private_key, server_public_key)
    server_derived_key = derive_keys(server_private_key, client_public_key)

    # NOTE: printing and storing secrets in plaintext is acceptable for a demo only.
    print("Client derived key: ", client_derived_key)
    print("Server derived key: ", server_derived_key)

    # write the keys to files; private keys are unencrypted PKCS#8 PEM, while
    # the "*_derived_key.pem" files hold the raw derived key bytes (not PEM)
    _write_bytes("client_private_key.pem", _private_key_pem(client_private_key))
    _write_bytes("server_private_key.pem", _private_key_pem(server_private_key))
    _write_bytes("client_derived_key.pem", client_derived_key)
    _write_bytes("server_derived_key.pem", server_derived_key)

    print(f"Client private key: \n\n{client_private_key}\nServer derived key: \n\n{server_derived_key}\n")
    print(f"Server private key: \n\n{server_private_key}\nClient derived key: \n\n{client_derived_key}\n")

    # test encryption: create a 1000 KiB file of zeros. subprocess.run waits
    # for dd to exit, so file.txt is complete before it is read below, and the
    # argv list avoids going through a shell.
    subprocess.run(
        ["dd", "if=/dev/zero", "of=file.txt", "bs=1024", "count=1000"],
        check=True,
    )
    with open("file.txt", 'rb') as f:
        data = f.read()

    encrypted_data = encrypt(data, client_derived_key)
    decrypted_data = decrypt(encrypted_data, server_derived_key)
    if data == decrypted_data:
        print("Encryption and decryption successful!")
    else:
        print("Encryption and decryption failed: decrypted data does not match the original!")