diff --git a/coti/crypto_utils.py b/coti/crypto_utils.py index 7cc5f15..649f085 100644 --- a/coti/crypto_utils.py +++ b/coti/crypto_utils.py @@ -5,6 +5,7 @@ from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import rsa from eth_keys import keys +from .types import ItString, ItUint block_size = AES.block_size address_size = 20 @@ -13,17 +14,17 @@ key_size = 32 -def encrypt(key, plaintext): +def encrypt(user_aes_key: bytes, plaintext: int): # Ensure plaintext is smaller than 128 bits (16 bytes) if len(plaintext) > block_size: raise ValueError("Plaintext size must be 128 bits or smaller.") # Ensure key size is 128 bits (16 bytes) - if len(key) != block_size: + if len(user_aes_key) != block_size: raise ValueError("Key size must be 128 bits.") # Create a new AES cipher block using the provided key - cipher = AES.new(key, AES.MODE_ECB) + cipher = AES.new(user_aes_key, AES.MODE_ECB) # Generate a random value 'r' of the same length as the block size r = get_random_bytes(block_size) @@ -40,12 +41,12 @@ def encrypt(key, plaintext): return ciphertext, r -def decrypt(key, r, ciphertext): +def decrypt(user_aes_key: bytes, r: bytes, ciphertext: bytes): if len(ciphertext) != block_size: raise ValueError("Ciphertext size must be 128 bits.") # Ensure key size is 128 bits (16 bytes) - if len(key) != block_size: + if len(user_aes_key) != block_size: raise ValueError("Key size must be 128 bits.") # Ensure random size is 128 bits (16 bytes) @@ -53,7 +54,7 @@ def decrypt(key, r, ciphertext): raise ValueError("Random size must be 128 bits.") # Create a new AES cipher block using the provided key - cipher = AES.new(key, AES.MODE_ECB) + cipher = AES.new(user_aes_key, AES.MODE_ECB) # Encrypt the random value 'r' using AES in ECB mode encrypted_r = cipher.encrypt(r) @@ -71,14 +72,14 @@ def generate_aes_key(): return key -def sign_input_text(sender, addr, function_selector, ct, key): +def sign_input_text(sender_address: str, contract_address: str, function_selector: str, ct, key): function_selector_bytes = bytes.fromhex(function_selector[2:]) # Ensure all input sizes are the correct length - if len(sender) != address_size: - raise ValueError(f"Invalid sender address length: {len(sender)} bytes, must be {address_size} bytes") - if len(addr) != address_size: - raise ValueError(f"Invalid contract address length: {len(addr)} bytes, must be {address_size} bytes") + if len(sender_address) != address_size: + raise ValueError(f"Invalid sender address length: {len(sender_address)} bytes, must be {address_size} bytes") + if len(contract_address) != address_size: + raise ValueError(f"Invalid contract address length: {len(contract_address)} bytes, must be {address_size} bytes") if len(function_selector_bytes) != function_selector_size: raise ValueError(f"Invalid signature size: {len(function_selector_bytes)} bytes, must be {function_selector_size} bytes") if len(ct) != ct_size: @@ -88,7 +89,7 @@ def sign_input_text(sender, addr, function_selector, ct, key): raise ValueError(f"Invalid key length: {len(key)} bytes, must be {key_size} bytes") # Create the message to be signed by appending all inputs - message = sender + addr + function_selector_bytes + ct + message = sender_address + contract_address + function_selector_bytes + ct return sign(message, key) @@ -101,19 +102,16 @@ def sign(message, key): return signature -def build_input_text(plaintext, user_aes_key, sender, contract, function_selector, signing_key): - sender_address_bytes = bytes.fromhex(sender.address[2:]) - contract_address_bytes = bytes.fromhex(contract.address[2:]) - +def build_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItUint: # Convert the integer to a byte slice with size aligned to 8. plaintext_bytes = plaintext.to_bytes((plaintext.bit_length() + 7) // 8, 'big') # Encrypt the plaintext with the user's AES key - ciphertext, r = encrypt(user_aes_key, plaintext_bytes) + ciphertext, r = encrypt(bytes.fromhex(user_aes_key), plaintext_bytes) ct = ciphertext + r # Sign the message - signature = sign_input_text(sender_address_bytes, contract_address_bytes, function_selector, ct, signing_key) + signature = sign_input_text(bytes.fromhex(sender_address[2:]), bytes.fromhex(contract_address[2:]), function_selector, ct, signing_key) # Convert the ct to an integer int_cipher_text = int.from_bytes(ct, byteorder='big') @@ -124,7 +122,7 @@ def build_input_text(plaintext, user_aes_key, sender, contract, function_selecto } -def build_string_input_text(plaintext, user_aes_key, sender, contract, function_selector, signing_key): +def build_string_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItString: input_text = { 'ciphertext': { 'value': [] @@ -142,8 +140,8 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_ it_int = build_input_text( int.from_bytes(byte_arr, 'big'), user_aes_key, - sender, - contract, + sender_address, + contract_address, function_selector, signing_key ) @@ -154,7 +152,7 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_ return input_text -def decrypt_uint(ciphertext, user_key): +def decrypt_uint(ciphertext: int, user_aes_key: str) -> int: # Convert ct to bytes (big-endian) byte_array = ciphertext.to_bytes(32, byteorder='big') @@ -163,7 +161,7 @@ def decrypt_uint(ciphertext, user_key): r = byte_array[block_size:] # Decrypt the cipher - decrypted_message = decrypt(user_key, r, cipher) + decrypted_message = decrypt(bytes.fromhex(user_aes_key), r, cipher) # Print the decrypted cipher decrypted_uint = int.from_bytes(decrypted_message, 'big') @@ -171,7 +169,7 @@ def decrypt_uint(ciphertext, user_key): return decrypted_uint -def decrypt_string(ciphertext, user_key): +def decrypt_string(ciphertext: int, user_aes_key: str) -> str: if 'value' in ciphertext or hasattr(ciphertext, 'value'): # format when reading ciphertext from an event __ciphertext = ciphertext['value'] elif isinstance(ciphertext, tuple): # format when reading ciphertext from state variable @@ -182,7 +180,7 @@ def decrypt_string(ciphertext, user_key): decrypted_string = "" for value in __ciphertext: - decrypted = decrypt_uint(value, user_key) + decrypted = decrypt_uint(value, user_aes_key) byte_length = (decrypted.bit_length() + 7) // 8 # calculate the byte length @@ -219,7 +217,7 @@ def generate_rsa_keypair(): return private_key_bytes, public_key_bytes -def decrypt_rsa(private_key_bytes, ciphertext): +def decrypt_rsa(private_key_bytes: bytes, ciphertext: bytes): # Load private key private_key = serialization.load_der_private_key(private_key_bytes, password=None) # Decrypt ciphertext @@ -235,7 +233,7 @@ def decrypt_rsa(private_key_bytes, ciphertext): #This function recovers a user's key by decrypting two encrypted key shares with the given private key, #and then XORing the two key shares together. -def recover_user_key(private_key_bytes, encrypted_key_share0, encrypted_key_share1): +def recover_user_key(private_key_bytes: bytes, encrypted_key_share0: bytes, encrypted_key_share1: bytes): key_share0 = decrypt_rsa(private_key_bytes, encrypted_key_share0) key_share1 = decrypt_rsa(private_key_bytes, encrypted_key_share1) diff --git a/coti/types.py b/coti/types.py new file mode 100644 index 0000000..4be834f --- /dev/null +++ b/coti/types.py @@ -0,0 +1,20 @@ +from typing import List, TypeAlias, TypedDict + +CtBool: TypeAlias = int + +CtUint: TypeAlias = int + +class CtString(TypedDict): + value: List[int] + +class ItBool(TypedDict): + ciphertext: int + signature: bytes + +class ItUint(TypedDict): + ciphertext: int + signature: bytes + +class ItString(TypedDict): + ciphertext: CtString + signature: List[bytes] \ No newline at end of file diff --git a/coti/utils.py b/coti/utils.py deleted file mode 100644 index 6e7149e..0000000 --- a/coti/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -import os - -from eth_account import Account -from web3 import Web3 -from web3.middleware import geth_poa_middleware - -SOLC_VERSION = '0.8.19' - - -def web3_connected(web3): - return web3.is_connected() - - -def print_network_details(web3): - print('provider: ', web3.eth.w3.provider.endpoint_uri) - print('chain-id: ', web3.eth.chain_id) - print('latest block: ', web3.eth.get_block('latest').get('hash').hex()) - - -def print_account_details(web3): - print('account address:', web3.eth.default_account.address) - print('account balance: ', get_native_balance(web3), 'wei (', web3.from_wei(get_native_balance(web3), 'ether'), - ' ether)') - print('account nonce: ', get_nonce(web3)) - - -def init_web3(node_https_address, eoa, error_not_connected=True): - web3 = Web3(Web3.HTTPProvider(node_https_address)) - web3.eth.default_account = eoa - web3.middleware_onion.inject(geth_poa_middleware, layer=0) - if error_not_connected: - if not web3_connected(web3): - raise Exception("Connection to node failed!") - print_network_details(web3) - print_account_details(web3) - return web3 - - -def get_eoa(account_private_key): - eoa = Account.from_key(account_private_key) - if address_valid(eoa.address): - return eoa - raise Exception("eoa from private key is not valid!") - - -def validate_address(address): - return {'valid': Web3.is_address(address), 'safe': Web3.to_checksum_address(address)} - - -def get_latest_block(web3): - return web3.eth.get_block('latest') - - -def get_nonce(web3): - return web3.eth.get_transaction_count(web3.eth.default_account.address) - - -def get_address_valid_and_checksum(address): - return {'valid': Web3.is_address(address), 'safe': Web3.to_checksum_address(address)} - - -def address_valid(address): - return get_address_valid_and_checksum(address)['valid'] - - -def get_native_balance(web3, address=None): - if address is None: - return web3.eth.get_balance(web3.eth.default_account.address) - if address is not None and address_valid(address): - return web3.eth.get_balance(address) - if not address_valid(address): - raise Exception('address ', address, ' is not valid!') - - -def load_contract(file_path): - # Ensure the file path is valid - if not os.path.exists(file_path): - raise Exception(f"The file {file_path} does not exists") - - # Read the Solidity source code from the file - with open(file_path, 'r') as file: - return file.read() - - -def transfer_native(web3, recipient_address, private_key, amount_to_transfer_ether, native_gas_units): - tx = { - 'to': recipient_address, - 'from': web3.eth.default_account.address, - 'value': web3.to_wei(amount_to_transfer_ether, 'ether'), # Transaction value (0.1 Ether in this example) - 'nonce': get_nonce(web3), - 'gas': native_gas_units, # Gas limit for the transaction - 'gasPrice': web3.eth.gas_price, - 'chainId': web3.eth.chain_id - } - validate_gas_estimation(web3, tx) - tx_receipt = sign_and_send_tx(web3, private_key, tx) - return tx_receipt - - -def validate_gas_estimation(web3, tx): - valid, gas_estimate = is_gas_units_estimation_valid(web3, tx) - if valid is False: - raise Exception('not enough gas for tx (provided: ' + str(tx.get('gas')) + ' needed by estimation: ' + str( - gas_estimate) + ')') - - -def is_gas_units_estimation_valid(web3, tx): - estimate_gas = web3.eth.estimate_gas(tx) - if tx['gas'] >= estimate_gas: - return True, estimate_gas - return False, estimate_gas - - -def deploy_contract(contract, kwargs, tx_params): - func = contract.constructor(**kwargs) - return exec_func_via_transaction(func, tx_params) - - -def exec_func_via_transaction(func, tx_params): - web3 = tx_params['web3'] - gas_limit = tx_params['gas_limit'] - gas_price_gwei = tx_params['gas_price_gwei'] - account_private_key = tx_params['eoa_private_key'] - tx = func.build_transaction({ - 'from': web3.eth.default_account.address, - 'chainId': web3.eth.chain_id, - 'nonce': get_nonce(web3), - 'gas': gas_limit, - 'gasPrice': web3.to_wei(gas_price_gwei, 'gwei') - }) - # validate_gas_estimation(web3, tx) - tx_receipt = sign_and_send_tx(web3, account_private_key, tx) - return tx_receipt - - -def sign_and_send_tx(web3, private_key, transaction): - try: - signed_tx = web3.eth.account.sign_transaction(transaction, private_key) - except Exception as e: - raise Exception(f"Failed to sign the transaction: {e}") - try: - tx_hash = web3.eth.send_raw_transaction(signed_tx.rawTransaction) - except Exception as e: - raise Exception(f"Failed to send the transaction: {e}") - try: - tx_receipt = web3.eth.wait_for_transaction_receipt(tx_hash) - except Exception as e: - raise Exception(f"Failed to wait for the transaction receipt: {e}") - return tx_receipt diff --git a/setup.py b/setup.py index 224e9f1..8cb7576 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ url='https://github.com/coti-io/coti-sdk-python', keywords='COTI SDK Privacy', install_requires=[ - 'pycryptodome==3.19.0', 'cryptography==3.4.8', 'eth-keys==0.4.0', 'eth-account==0.10.0', 'web3==6.11.2' + 'pycryptodome==3.19.0', 'cryptography==3.4.8', 'eth-keys>=0.4.0', 'eth-account>=0.13.1' ], python_requires=">=3.9", )