import sys, os

from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend

# -----------------
# encrypt functions
# -----------------

# function to generate a 256-bit symmetric key
def generate_symmetric_key():
	"""Return a fresh random 256-bit (32-byte) key read from ``os.urandom``."""
	key_length_bytes = 256 // 8
	return os.urandom(key_length_bytes)


# function to encrypt data using a symmetric key
def encrypt_symmetric(key, plain_text):
	"""Encrypt *plain_text* with AES in CFB mode under *key*.

	A new 16-byte IV is drawn from ``os.urandom`` on every call and is
	prepended to the output, so the result is laid out as ``iv || ciphertext``.

	NOTE(review): CFB gives confidentiality only; no MAC/tag is attached,
	so tampering with the output is not detected on decryption.
	"""
	nonce = os.urandom(16)
	aes_cfb = Cipher(
		algorithms.AES(key),
		modes.CFB(nonce),
		backend=default_backend(),
	)
	enc = aes_cfb.encryptor()
	body = enc.update(plain_text) + enc.finalize()
	return b"".join((nonce, body))


# function that calls and combines the symmetric and asymmetric encryption
def encrypt_hybrid(public_key, plaintext):
	"""Encrypt *plaintext* with a one-off AES key, itself wrapped with *public_key*.

	Output layout: ``wrapped_key || iv || ciphertext``. The wrapped key is
	produced with OAEP (MGF1-SHA256, SHA256, no label). For RSA keys its
	length equals the modulus size in bytes, which is how the decrypting
	side knows where the symmetric payload starts.
	"""
	session_key = generate_symmetric_key()
	sealed_payload = encrypt_symmetric(session_key, plaintext)

	oaep = padding.OAEP(
		mgf=padding.MGF1(algorithm=hashes.SHA256()),
		algorithm=hashes.SHA256(),
		label=None,
	)
	wrapped_key = public_key.encrypt(session_key, oaep)

	return wrapped_key + sealed_payload


# main function to encrypt the file
def encrypt_file(public_key, original_file, encrypted_file):
	"""Hybrid-encrypt the bytes of *original_file* and store them in *encrypted_file*.

	The whole input is read and encrypted before the output file is opened,
	so a failure during encryption leaves *encrypted_file* untouched.
	"""
	with open(original_file, 'rb') as source:
		blob = encrypt_hybrid(public_key, source.read())

	with open(encrypted_file, 'wb') as sink:
		sink.write(blob)


# function to load a public key from a file
def load_public_key(file, *, as_pem=True):
	"""Load a PEM-encoded public key from *file*.

	Args:
		file: Path of the PEM file to read.
		as_pem: When True (the default, kept for backward compatibility) the
			key is re-serialized and returned as a SubjectPublicKeyInfo PEM
			``str``. When False, the loaded key object itself is returned.
			``encrypt_file``/``encrypt_hybrid`` need the key object, because
			they call ``public_key.encrypt``, which a PEM string does not have.

	Returns:
		A PEM ``str`` if *as_pem* is True, otherwise the public key object.

	Raises:
		OSError: if the file cannot be opened or read.
		ValueError: propagated from ``load_pem_public_key`` when the PEM data
			cannot be decoded.
	"""
	with open(file, 'rb') as key_file:
		public_key = serialization.load_pem_public_key(key_file.read())

	if not as_pem:
		return public_key

	# Re-serialize so the returned text is always SubjectPublicKeyInfo PEM,
	# whichever accepted PEM flavor was stored on disk.
	return public_key.public_bytes(
		encoding=serialization.Encoding.PEM,
		format=serialization.PublicFormat.SubjectPublicKeyInfo,
	).decode('utf-8')

# -----------------
# decrypt functions
# -----------------

# function to decrypt data using a symmetric key
def decrypt_symmetric(key, ciphertext):
	"""Decrypt data produced by ``encrypt_symmetric`` (AES in CFB mode).

	Args:
		key: The AES key that was used for encryption.
		ciphertext: ``iv || ciphertext``; the first 16 bytes are the IV.

	Returns:
		The recovered plaintext bytes.

	Raises:
		ValueError: if *ciphertext* is shorter than the 16-byte IV.

	NOTE(review): CFB is unauthenticated, so a wrong key or tampered input
	yields garbage plaintext rather than an error.
	"""
	iv_size = 16  # must match the IV length prepended by encrypt_symmetric
	if len(ciphertext) < iv_size:
		raise ValueError(
			f"Ciphertext too short: expected at least {iv_size} bytes "
			f"for the IV, got {len(ciphertext)}."
		)

	# Extract the IV stored in clear in front of the encrypted payload.
	iv = ciphertext[:iv_size]
	payload = ciphertext[iv_size:]

	cipher = Cipher(algorithms.AES(key), modes.CFB(iv), backend=default_backend())
	decryptor = cipher.decryptor()
	return decryptor.update(payload) + decryptor.finalize()


# function that calls and combines the symmetric and asymmetric decryption
def decrypt_hybrid(private_key, encrypted_data):
	"""Reverse ``encrypt_hybrid``: unwrap the AES key with RSA, then decrypt the payload.

	Args:
		private_key: RSA private key matching the public key used to encrypt.
		encrypted_data: ``wrapped_key || iv || ciphertext`` as produced by
			``encrypt_hybrid``.

	Returns:
		The recovered plaintext bytes.

	Raises:
		ValueError: if *encrypted_data* is shorter than the wrapped key, or
			if RSA-OAEP decryption fails (raised by ``private_key.decrypt``).
	"""
	# RSA-OAEP output is exactly the modulus length in bytes, i.e.
	# ceil(key_size / 8). Floor division would split one byte too early for
	# moduli whose bit length is not a multiple of 8.
	wrapped_key_len = (private_key.key_size + 7) // 8
	if len(encrypted_data) < wrapped_key_len:
		raise ValueError(
			f"Encrypted data too short: expected at least {wrapped_key_len} "
			f"bytes for the wrapped key, got {len(encrypted_data)}."
		)

	encrypted_symmetric_key = encrypted_data[:wrapped_key_len]
	symmetric_payload = encrypted_data[wrapped_key_len:]

	# Same OAEP parameters as used on the encrypting side.
	symmetric_key = private_key.decrypt(
		encrypted_symmetric_key,
		padding.OAEP(
			mgf=padding.MGF1(algorithm=hashes.SHA256()),
			algorithm=hashes.SHA256(),
			label=None
		)
	)

	return decrypt_symmetric(symmetric_key, symmetric_payload)


# main function to decrypt the file
def decrypt_file(private_key, encrypted_file, decrypted_file=None):
	"""Decrypt *encrypted_file* with *private_key*.

	If *decrypted_file* is None the plaintext bytes are returned; otherwise
	they are written to *decrypted_file* and None is returned.
	"""
	with open(encrypted_file, 'rb') as src:
		recovered = decrypt_hybrid(private_key, src.read())

	if decrypted_file is None:
		return recovered

	with open(decrypted_file, 'wb') as dst:
		dst.write(recovered)


# function to load a private key from a file
def load_private_key(file, passwd=None):
	"""Load a PEM-encoded private key from *file*.

	Args:
		file: Path of the PEM file to read.
		passwd: Password protecting the key: a ``str`` (encoded as UTF-8),
			``bytes`` (used as-is), or None for an unencrypted key.

	Returns:
		The private key object returned by ``load_pem_private_key``.

	Raises:
		OSError: if the file cannot be opened or read.
		ValueError: if the key cannot be decrypted with *passwd* or its PEM
			data cannot be decoded.
		TypeError: propagated from cryptography when a password is given for
			an unencrypted key or missing for an encrypted one (per its docs).
	"""
	if isinstance(passwd, str):
		passwd = passwd.encode('utf-8')

	# Read outside the try so path/I-O problems are not reported as a bad password.
	with open(file, 'rb') as key_file:
		pem_data = key_file.read()

	try:
		private_key = serialization.load_pem_private_key(
			pem_data,
			password=passwd,
		)
	except ValueError as e:
		# cryptography raises ValueError both for a wrong password and for
		# undecodable PEM data, so the message must not blame only the password.
		raise ValueError(
			"Error: The password is not valid or the key data could not be decoded."
		) from e

	return private_key