diff --git a/evm/utils/address.py b/evm/utils/address.py index 2cf9d71d09..57b43e35c7 100644 --- a/evm/utils/address.py +++ b/evm/utils/address.py @@ -10,6 +10,10 @@ private_key_to_public_key, ) +from evm.validation import ( + validate_raw_public_key +) + def force_bytes_to_address(value): trimmed_value = value[-20:] @@ -27,8 +31,5 @@ def private_key_to_address(private_key): def public_key_to_address(public_key): - if len(public_key) != 64: - raise ValueError( - "Unexpected public key format: {}. Public keys must be 64 bytes long and must not " - "include the fixed \x04 prefix".format(public_key)) + validate_raw_public_key(public_key) return keccak(public_key)[-20:] diff --git a/evm/utils/secp256k1.py b/evm/utils/secp256k1.py index af8b5f911a..8547c2ffe2 100644 --- a/evm/utils/secp256k1.py +++ b/evm/utils/secp256k1.py @@ -17,12 +17,13 @@ pad32, ) +from evm.validation import ( + validate_raw_public_key +) + def decode_public_key(public_key): - if len(public_key) != 64: - raise ValueError( - "Unexpected public key format: {}. Public keys must be 64 bytes long and must not " - "include the fixed \x04 prefix".format(public_key)) + validate_raw_public_key(public_key) left = big_endian_to_int(public_key[0:32]) right = big_endian_to_int(public_key[32:64]) return left, right diff --git a/evm/validation.py b/evm/validation.py index 7c2a04bf5d..f281e826b8 100644 --- a/evm/validation.py +++ b/evm/validation.py @@ -159,3 +159,9 @@ def validate_vm_block_numbers(vm_block_numbers): for block_number in vm_block_numbers: validate_block_number(block_number) + + +def validate_raw_public_key(value): + validate_is_bytes(value) + if len(value) != 64: + raise ValidationError("Unexpected public key format. Must be length 64 byte") diff --git a/tests/core/validation/test_validation.py b/tests/core/validation/test_validation.py index 0c0cde76dd..b525ffc1d9 100644 --- a/tests/core/validation/test_validation.py +++ b/tests/core/validation/test_validation.py @@ -23,6 +23,7 @@ validate_lt_secpk1n, validate_lt_secpk1n2, validate_multiple_of, + validate_raw_public_key, validate_stack_item, validate_uint256, validate_unique, @@ -31,6 +32,9 @@ ) +byte = b"\x00" + + @pytest.mark.parametrize( "value,is_valid", ( @@ -393,3 +397,19 @@ def test_validate_vm_block_numbers(vm_block_numbers, is_valid): else: with pytest.raises(ValidationError): validate_vm_block_numbers(vm_block_numbers) + + +@pytest.mark.parametrize( + "public_key_value,is_valid", + ( + (byte, False), + (("1" * 64), False), + (byte * 64, True), + ), +) +def test_validate_raw_public_key(public_key_value, is_valid): + if is_valid: + validate_raw_public_key(public_key_value) + else: + with pytest.raises(ValidationError): + validate_raw_public_key(public_key_value)