Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 92 additions & 52 deletions bitcoinutils/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,74 +528,113 @@ class PublicKey:
returns the corresponding P2trAddress object
"""

def __init__(self, hex_str: str) -> None:
def __init__(self, hex_str: str = None, message: str = None, signature: bytes = None) -> None:
"""
Parameters
----------
hex_str : str
hex_str : str, optional
the public key in hex string

In case of generating public key from message and signature:-
message : str, optional
The original message that was signed
signature : bytes, optional
A 65-byte Bitcoin signature (1-byte recovery ID + 64-byte ECDSA signature).

Raises
------
TypeError
If first byte of public key (corresponding to SEC format) is
invalid.
If neither hex_str nor (message, signature) are provided
ValueError
If message is empty when attempting recovery
If signature is not exactly 65 bytes
If an invalid recovery ID is detected
"""
hex_str = hex_str.strip()
if hex_str:
hex_str = hex_str.strip()

# Normalize hex string by removing '0x' prefix and any whitespace
if hex_str.lower().startswith('0x'):
hex_str = hex_str[2:]
# Normalize hex string by removing '0x' prefix and any whitespace
if hex_str.lower().startswith('0x'):
hex_str = hex_str[2:]

# expects key as hex string - SEC format
first_byte_in_hex = hex_str[:2] # 2 hex chars = 1 byte
hex_bytes = h_to_b(hex_str)
# expects key as hex string - SEC format
first_byte_in_hex = hex_str[:2] # 2 hex chars = 1 byte
hex_bytes = h_to_b(hex_str)

taproot = False
taproot = False

# check if compressed or not
if len(hex_bytes) > 33:
# uncompressed - SEC format: 0x04 + x + y coordinates (x,y are 32 byte
# numbers)
# check if compressed or not
if len(hex_bytes) > 33:
# uncompressed - SEC format: 0x04 + x + y coordinates (x,y are 32 byte
# numbers)

# remove first byte and instantiate ecdsa key
self.key = VerifyingKey.from_string(hex_bytes[1:], curve=SECP256k1)
elif len(hex_bytes) > 31:
# key is either compressed or in x-only taproot format
# remove first byte and instantiate ecdsa key
self.key = VerifyingKey.from_string(hex_bytes[1:], curve=SECP256k1)
elif len(hex_bytes) > 31:
# key is either compressed or in x-only taproot format

# taproot public keys are exactly 32 bytes
if len(hex_bytes) == 32:
taproot = True
# taproot public keys are exactly 32 bytes
if len(hex_bytes) == 32:
taproot = True

# compressed - SEC FORMAT: 0x02|0x03 + x coordinate (if 02 then y
# is even else y is odd. Calculate y and then instantiate the ecdsa key
x_coord = int(hex_str[2:], 16)
# compressed - SEC FORMAT: 0x02|0x03 + x coordinate (if 02 then y
# is even else y is odd. Calculate y and then instantiate the ecdsa key
x_coord = int(hex_str[2:], 16)

# y = modulo_square_root( (x**3 + 7) mod p ) -- there will be 2 y values
y_values = sqrt_mod(
(x_coord**3 + 7) % Secp256k1Params._p, Secp256k1Params._p, True
)
# y = modulo_square_root( (x**3 + 7) mod p ) -- there will be 2 y values
y_values = sqrt_mod(
(x_coord**3 + 7) % Secp256k1Params._p, Secp256k1Params._p, True
)

assert y_values is not None
# check SEC format's first byte to determine which of the 2 values to use
if first_byte_in_hex == "02" or taproot:
# y is the even value
if y_values[0] % 2 == 0: # type: ignore
y_coord = y_values[0] # type: ignore
assert y_values is not None
# check SEC format's first byte to determine which of the 2 values to use
if first_byte_in_hex == "02" or taproot:
# y is the even value
if y_values[0] % 2 == 0: # type: ignore
y_coord = y_values[0] # type: ignore
else:
y_coord = y_values[1] # type: ignore
elif first_byte_in_hex == "03":
# y is the odd value
if y_values[0] % 2 == 0: # type: ignore
y_coord = y_values[1] # type: ignore
else:
y_coord = y_values[0] # type: ignore
else:
y_coord = y_values[1] # type: ignore
elif first_byte_in_hex == "03":
# y is the odd value
if y_values[0] % 2 == 0: # type: ignore
y_coord = y_values[1] # type: ignore
else:
y_coord = y_values[0] # type: ignore
else:
raise TypeError("Invalid SEC compressed format")

uncompressed_hex = f"{x_coord:064x}{y_coord:064x}"
uncompressed_hex_bytes = h_to_b(uncompressed_hex)
self.key = VerifyingKey.from_string(uncompressed_hex_bytes, curve=SECP256k1)
raise TypeError("Invalid SEC compressed format")

uncompressed_hex = f"{x_coord:064x}{y_coord:064x}"
uncompressed_hex_bytes = h_to_b(uncompressed_hex)
self.key = VerifyingKey.from_string(uncompressed_hex_bytes, curve=SECP256k1)
elif message or signature:
if not message:
raise ValueError("Empty message provided for public key recovery.")

if(len(signature) != 65):
raise ValueError("Invalid signature length, must be exactly 65 bytes")

# The compressed signature is of the format: recovery_id (1 byte) | r (32 bytes) | s (32 bytes)
# We subtract the prefix(27) for uncompressed signatures and an additional 4 (31) for compressed signatures to get the recovery id
recovery_id = signature[0] - 31
if not (0 <= recovery_id <= 3): # A valid recovery ID is between 0 and 3
raise ValueError(f"Invalid recovery ID: expected 31-34, got {signature[0]}")

signature = signature[1:] #Remove recovery id from signature

# All bitcoin signatures include the magic prefix. It is just a string
# added to the message to distinguish Bitcoin-specific messages.
message_magic = add_magic_prefix(message)
# create message digest
message_digest = hashlib.sha256(hashlib.sha256(message_magic).digest()).digest()

recovered_keys = VerifyingKey.from_public_key_recovery_with_digest(
signature, message_digest, curve=SECP256k1, hashfunc = hashlib.sha256, sigdecode=sigdecode_string
)
self.key = recovered_keys[recovery_id]
else:
raise TypeError("Either 'hex_str' or ('message', 'signature') must be provided.")

@classmethod
def from_hex(cls, hex_str: str) -> PublicKey:
Expand Down Expand Up @@ -665,11 +704,12 @@ def is_y_even(self) -> bool:
return y % 2 == 0

@classmethod
def from_message_signature(cls, signature):
# TODO implement (add signature=None in __init__, etc.)
# TODO plus does this apply to DER signatures as well?
# return cls(signature=signature)
raise BaseException("NO-OP!")
def from_message_signature(cls, message, signature):
"""Recovers a public key from a Bitcoin-signed message and a 65-byte compressed signature.
"""
#Note: Only works for compressed signatures because DER encoding does not contain the recovery id
return cls(message=message, signature=signature)
# raise BaseException("NO-OP!")

@classmethod
def verify_message(cls, address: str, signature: str, message: str) -> bool:
Expand Down
43 changes: 41 additions & 2 deletions tests/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from bitcoinutils.script import Script
from bitcoinutils.hdwallet import HDWallet
from base64 import b64decode


class TestPrivateKeys(unittest.TestCase):
Expand Down Expand Up @@ -72,6 +73,13 @@ def setUp(self):
b"\x08\xa8\xfd\x17\xb4H\xa6\x85T\x19\x9cG\xd0\x8f\xfb\x10\xd4\xb8"
)
self.address = "1EHNa6Q4Jz2uvNExL497mE43ikXhwF6kZm"

# Message public key recovery tests
self.valid_message = "Hello, Bitcoin!"
# 65-byte Bitcoin signature (1-byte recovery ID + 64-byte ECDSA signature)
self.valid_signature = b'\x1f\x0c\xfc\xd8V\xec27)\xa7\xfc\x02:\xda\xcfT\xb2*\x02\x16.\xe2s\x7f\x18[&^\xb3e\xee3"KN\xfct\x011Z[\x05\xb5\xea\n!\xe8\xce\x9em\x89/\xf2\xa0\x15\x83{\x7f\x9e\xba+\xb4\xf8&\x15'
# Known valid public key corresponding to the message + signature
self.expected_public_key = '02649abc7094d2783670255073ccfd132677555ca84045c5a005611f25ef51fdbf'

def test_pubkey_creation(self):
pub1 = PublicKey(self.public_key_hex)
Expand All @@ -98,6 +106,38 @@ def test_pubkey_to_hash160(self):
def test_pubkey_x_only(self):
pub = PublicKey(self.public_key_hex)
self.assertEqual(pub.to_x_only_hex(), self.public_key_hex[2:66])

#Tests for PublicKey recovery from message and signature
def test_public_key_recovery_valid(self):
"""Test successful public key recovery from a valid message and signature"""
pubkey = PublicKey(message=self.valid_message, signature=self.valid_signature)
self.assertEqual(pubkey.key.to_string("compressed").hex(), self.expected_public_key)

def test_invalid_signature_length(self):
"""Test handling of invalid signature length (not 65 bytes)"""
short_signature = self.valid_signature[:60] # Truncate signature to 60 bytes
with self.assertRaises(ValueError) as context:
PublicKey(message=self.valid_message, signature=short_signature)
self.assertEqual(str(context.exception), "Invalid signature length, must be exactly 65 bytes")

def test_invalid_recovery_id(self):
"""Test handling of an invalid recovery ID"""
invalid_signature = bytes([50]) + self.valid_signature[1:] # Modify recovery ID to 50
with self.assertRaises(ValueError) as context:
PublicKey(message=self.valid_message, signature=invalid_signature)
self.assertIn("Invalid recovery ID", str(context.exception))

def test_missing_parameters(self):
"""Test that missing both hex_str and (message, signature) raises an error"""
with self.assertRaises(TypeError) as context:
PublicKey()
self.assertEqual(str(context.exception), "Either 'hex_str' or ('message', 'signature') must be provided.")

def test_empty_message(self):
"""Test handling of an empty message for public key recovery"""
with self.assertRaises(ValueError) as context:
PublicKey(message="", signature=self.valid_signature)
self.assertEqual(str(context.exception), "Empty message provided for public key recovery.")


class TestP2pkhAddresses(unittest.TestCase):
Expand Down Expand Up @@ -311,7 +351,6 @@ def test_legacy_address_from_mnemonic(self):
hdw.from_path("m/44'/1'/0'/0/3")
address = hdw.get_private_key().get_public_key().get_address()
self.assertTrue(address.to_string(), self.legacy_address_m_44_1h_0h_0_3)



if __name__ == "__main__":
unittest.main()