Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions coti/crypto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from eth_keys import keys
from .types import ItString, ItUint

block_size = AES.block_size
address_size = 20
Expand All @@ -13,17 +14,17 @@
key_size = 32


def encrypt(key, plaintext):
def encrypt(user_aes_key: bytes, plaintext: int):
# Ensure plaintext is smaller than 128 bits (16 bytes)
if len(plaintext) > block_size:
raise ValueError("Plaintext size must be 128 bits or smaller.")

# Ensure key size is 128 bits (16 bytes)
if len(key) != block_size:
if len(user_aes_key) != block_size:
raise ValueError("Key size must be 128 bits.")

# Create a new AES cipher block using the provided key
cipher = AES.new(key, AES.MODE_ECB)
cipher = AES.new(user_aes_key, AES.MODE_ECB)

# Generate a random value 'r' of the same length as the block size
r = get_random_bytes(block_size)
Expand All @@ -40,20 +41,20 @@ def encrypt(key, plaintext):
return ciphertext, r


def decrypt(key, r, ciphertext):
def decrypt(user_aes_key: bytes, r: bytes, ciphertext: bytes):
if len(ciphertext) != block_size:
raise ValueError("Ciphertext size must be 128 bits.")

# Ensure key size is 128 bits (16 bytes)
if len(key) != block_size:
if len(user_aes_key) != block_size:
raise ValueError("Key size must be 128 bits.")

# Ensure random size is 128 bits (16 bytes)
if len(r) != block_size:
raise ValueError("Random size must be 128 bits.")

# Create a new AES cipher block using the provided key
cipher = AES.new(key, AES.MODE_ECB)
cipher = AES.new(user_aes_key, AES.MODE_ECB)

# Encrypt the random value 'r' using AES in ECB mode
encrypted_r = cipher.encrypt(r)
Expand All @@ -71,14 +72,14 @@ def generate_aes_key():
return key


def sign_input_text(sender, addr, function_selector, ct, key):
def sign_input_text(sender_address: str, contract_address: str, function_selector: str, ct, key):
function_selector_bytes = bytes.fromhex(function_selector[2:])

# Ensure all input sizes are the correct length
if len(sender) != address_size:
raise ValueError(f"Invalid sender address length: {len(sender)} bytes, must be {address_size} bytes")
if len(addr) != address_size:
raise ValueError(f"Invalid contract address length: {len(addr)} bytes, must be {address_size} bytes")
if len(sender_address) != address_size:
raise ValueError(f"Invalid sender address length: {len(sender_address)} bytes, must be {address_size} bytes")
if len(contract_address) != address_size:
raise ValueError(f"Invalid contract address length: {len(contract_address)} bytes, must be {address_size} bytes")
if len(function_selector_bytes) != function_selector_size:
raise ValueError(f"Invalid signature size: {len(function_selector_bytes)} bytes, must be {function_selector_size} bytes")
if len(ct) != ct_size:
Expand All @@ -88,7 +89,7 @@ def sign_input_text(sender, addr, function_selector, ct, key):
raise ValueError(f"Invalid key length: {len(key)} bytes, must be {key_size} bytes")

# Create the message to be signed by appending all inputs
message = sender + addr + function_selector_bytes + ct
message = sender_address + contract_address + function_selector_bytes + ct

return sign(message, key)

Expand All @@ -101,19 +102,16 @@ def sign(message, key):
return signature


def build_input_text(plaintext, user_aes_key, sender, contract, function_selector, signing_key):
sender_address_bytes = bytes.fromhex(sender.address[2:])
contract_address_bytes = bytes.fromhex(contract.address[2:])

def build_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItUint:
# Convert the integer to a byte slice with size aligned to 8.
plaintext_bytes = plaintext.to_bytes((plaintext.bit_length() + 7) // 8, 'big')

# Encrypt the plaintext with the user's AES key
ciphertext, r = encrypt(user_aes_key, plaintext_bytes)
ciphertext, r = encrypt(bytes.fromhex(user_aes_key), plaintext_bytes)
ct = ciphertext + r

# Sign the message
signature = sign_input_text(sender_address_bytes, contract_address_bytes, function_selector, ct, signing_key)
signature = sign_input_text(bytes.fromhex(sender_address[2:]), bytes.fromhex(contract_address[2:]), function_selector, ct, signing_key)

# Convert the ct to an integer
int_cipher_text = int.from_bytes(ct, byteorder='big')
Expand All @@ -124,7 +122,7 @@ def build_input_text(plaintext, user_aes_key, sender, contract, function_selecto
}


def build_string_input_text(plaintext, user_aes_key, sender, contract, function_selector, signing_key):
def build_string_input_text(plaintext: int, user_aes_key: str, sender_address: str, contract_address: str, function_selector: str, signing_key: str) -> ItString:
input_text = {
'ciphertext': {
'value': []
Expand All @@ -142,8 +140,8 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
it_int = build_input_text(
int.from_bytes(byte_arr, 'big'),
user_aes_key,
sender,
contract,
sender_address,
contract_address,
function_selector,
signing_key
)
Expand All @@ -154,7 +152,7 @@ def build_string_input_text(plaintext, user_aes_key, sender, contract, function_
return input_text


def decrypt_uint(ciphertext, user_key):
def decrypt_uint(ciphertext: int, user_aes_key: str) -> int:
# Convert ct to bytes (big-endian)
byte_array = ciphertext.to_bytes(32, byteorder='big')

Expand All @@ -163,15 +161,15 @@ def decrypt_uint(ciphertext, user_key):
r = byte_array[block_size:]

# Decrypt the cipher
decrypted_message = decrypt(user_key, r, cipher)
decrypted_message = decrypt(bytes.fromhex(user_aes_key), r, cipher)

# Print the decrypted cipher
decrypted_uint = int.from_bytes(decrypted_message, 'big')

return decrypted_uint


def decrypt_string(ciphertext, user_key):
def decrypt_string(ciphertext: int, user_aes_key: str) -> str:
if 'value' in ciphertext or hasattr(ciphertext, 'value'): # format when reading ciphertext from an event
__ciphertext = ciphertext['value']
elif isinstance(ciphertext, tuple): # format when reading ciphertext from state variable
Expand All @@ -182,7 +180,7 @@ def decrypt_string(ciphertext, user_key):
decrypted_string = ""

for value in __ciphertext:
decrypted = decrypt_uint(value, user_key)
decrypted = decrypt_uint(value, user_aes_key)

byte_length = (decrypted.bit_length() + 7) // 8 # calculate the byte length

Expand Down Expand Up @@ -219,7 +217,7 @@ def generate_rsa_keypair():
return private_key_bytes, public_key_bytes


def decrypt_rsa(private_key_bytes, ciphertext):
def decrypt_rsa(private_key_bytes: bytes, ciphertext: bytes):
# Load private key
private_key = serialization.load_der_private_key(private_key_bytes, password=None)
# Decrypt ciphertext
Expand All @@ -235,7 +233,7 @@ def decrypt_rsa(private_key_bytes, ciphertext):

#This function recovers a user's key by decrypting two encrypted key shares with the given private key,
#and then XORing the two key shares together.
def recover_user_key(private_key_bytes, encrypted_key_share0, encrypted_key_share1):
def recover_user_key(private_key_bytes: bytes, encrypted_key_share0: bytes, encrypted_key_share1: bytes):
key_share0 = decrypt_rsa(private_key_bytes, encrypted_key_share0)
key_share1 = decrypt_rsa(private_key_bytes, encrypted_key_share1)

Expand Down
20 changes: 20 additions & 0 deletions coti/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, TypeAlias, TypedDict

CtBool: TypeAlias = int

CtUint: TypeAlias = int

class CtString(TypedDict):
value: List[int]

class ItBool(TypedDict):
ciphertext: int
signature: bytes

class ItUint(TypedDict):
ciphertext: int
signature: bytes

class ItString(TypedDict):
ciphertext: CtString
signature: List[bytes]
149 changes: 0 additions & 149 deletions coti/utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
url='https://github.com/coti-io/coti-sdk-python',
keywords='COTI SDK Privacy',
install_requires=[
'pycryptodome==3.19.0', 'cryptography==3.4.8', 'eth-keys==0.4.0', 'eth-account==0.10.0', 'web3==6.11.2'
'pycryptodome==3.19.0', 'cryptography==3.4.8', 'eth-keys>=0.4.0', 'eth-account>=0.13.1'
],
python_requires=">=3.9",
)
Loading