diff --git a/bitcoinutils/psbt.py b/bitcoinutils/psbt.py new file mode 100644 index 00000000..0909ad22 --- /dev/null +++ b/bitcoinutils/psbt.py @@ -0,0 +1,147 @@ +from bitcoinutils.transactions import ( + Transaction, +) + +from bitcoinutils.utils import ( + encode_varint, + vi_to_int +) + +MAGIC_BYTES = b"psbt\xff" +SEPARATOR = b'\x00' + +# Global key types +PSBT_GLOBAL_UNSIGNED_TX = b'\x00' +PSBT_GLOBAL_XPUB = b'\x01' +PSBT_GLOBAL_TX_VERSION = b'\x02' +PSBT_GLOBAL_FALLBACK_LOCKTIME = b'\x03' +PSBT_GLOBAL_INPUT_COUNT = b'\x04' +PSBT_GLOBAL_OUTPUT_COUNT = b'\x05' +PSBT_GLOBAL_TX_MODIFIABLE = b'\x06' +PSBT_GLOBAL_SP_ECDH_SHARE = b'\x07' +PSBT_GLOBAL_SP_DLEQ = b'\x08' +PSBT_GLOBAL_VERSION = b'\xFB' +PSBT_GLOBAL_PROPRIETARY = b'\xFC' + +# Per-input key types +PSBT_IN_NON_WITNESS_UTXO = b'\x00' +PSBT_IN_WITNESS_UTXO = b'\x01' +PSBT_IN_PARTIAL_SIG = b'\x02' +PSBT_IN_SIGHASH_TYPE = b'\x03' +PSBT_IN_REDEEM_SCRIPT = b'\x04' +PSBT_IN_WITNESS_SCRIPT = b'\x05' +PSBT_IN_BIP32_DERIVATION = b'\x06' +PSBT_IN_FINAL_SCRIPTSIG = b'\x07' +PSBT_IN_FINAL_SCRIPTWITNESS = b'\x08' + +# Per-output key types +PSBT_OUT_REDEEM_SCRIPT = b'\x00' +PSBT_OUT_WITNESS_SCRIPT = b'\x01' +PSBT_OUT_BIP32_DERIVATION = b'\x02' +PSBT_OUT_AMOUNT = b'\x03' +PSBT_OUT_SCRIPT = b'\x04' + + +class PSBT: + def __init__(self, maps: dict): + ''' + Parameters + ---------- + maps : dict + A dictionary with the keys 'global', 'input' and 'output' containing the corresponding maps.''' + self.maps = maps + #TODO: add checks to validate psbt (will be added in future PRs) + + @staticmethod + def serialize_key_val(key: bytes, val: bytes): + '''Serialize a key value pair, key, val should be bytes''' + return encode_varint(len(key)) + key + encode_varint(len(val)) + val + + @staticmethod + def parse_key_value(s): + """Parse a key-value pair from the PSBT stream.""" + # Read the first byte to determine the key length + key_length_bytes = s.read(1) + key_length, _ = vi_to_int(key_length_bytes) + # If key length is 0, return None (indicates a separator) + if key_length == 0: + return None, None + # Read the key + key = s.read(key_length) + + # Read the value length + val_length_bytes = s.read(1) + val_length, _ = vi_to_int(val_length_bytes) + # Read the value + val = s.read(val_length) + + return key, val + + def serialize(self): + psbt = MAGIC_BYTES + # Here we are including keytype and keydata in key, therefore serialize_key_val() works as intended + for key, val in sorted(self.maps['global'].items()): + psbt += self.serialize_key_val(key, val) + psbt += SEPARATOR + for inp in self.maps['input']: + for key, val in sorted(inp.items()): + psbt += self.serialize_key_val(key, val) + psbt += SEPARATOR + for out in self.maps['output']: + for key, val in sorted(out.items()): + psbt += self.serialize_key_val(key, val) + psbt += SEPARATOR + return psbt + + + @classmethod + def parse(cls, s): + if s.read(5) != MAGIC_BYTES: + raise ValueError('Invalid PSBT magic bytes') + maps = {'global': {}, 'input': [], 'output': []} + + globals = True #To check if paresed key value is from global map + input_ind = 0 + output_ind = 0 + + while globals or input_ind > 0 or output_ind > 0: + key, val = PSBT.parse_key_value(s) + + if globals: + if key is None: #Separator is reached indicating end of global map + globals = False + continue + + maps['global'][key] = val + + + if key == PSBT_GLOBAL_UNSIGNED_TX: #If unsigned transaction is found, intialize input and output maps + hex_val = val.hex() + transaction = Transaction.from_raw(hex_val) + input_ind = len(transaction.inputs) + output_ind = len(transaction.outputs) + # input_ind = 1 + # output_ind = 1 + maps['input'] = [{} for _ in range(input_ind)] + maps['output'] = [{} for _ in range(output_ind)] + + elif input_ind > 0: # Means input map is being parsed + if key is None: #Separator is reached; indicating end of the particular input map, there can be multiple input maps + input_ind -= 1 + continue + + ind = input_ind - len(maps['input']) #Get the index of the input being parsed + maps['input'][ind][key] = val + + elif output_ind > 0: # Means output map is being parsed + + if key is None: #Separator is reached; indicating end of the particular output map, there can be multiple output maps + output_ind -= 1 + continue + + ind = output_ind - len(maps['output']) #Get the index of the output being parsed + maps['output'][ind][key] = val + + return cls(maps) + + #TODO: Add methods to parse and serialize psbt as b64 and hex (will be added in future PRs) \ No newline at end of file diff --git a/tests/test_psbt.py b/tests/test_psbt.py new file mode 100644 index 00000000..5029fe85 --- /dev/null +++ b/tests/test_psbt.py @@ -0,0 +1,85 @@ +import unittest +from io import BytesIO +from bitcoinutils.psbt import PSBT, MAGIC_BYTES, SEPARATOR, PSBT_GLOBAL_UNSIGNED_TX +from bitcoinutils.transactions import Transaction +from bitcoinutils.utils import encode_varint + + +class TestPSBT(unittest.TestCase): + def setUp(self): + # Sample PSBT maps for testing + self.sample_maps = { + 'global': { + PSBT_GLOBAL_UNSIGNED_TX: bytes.fromhex( + "0100000001c3b5b9b07ec40d9e3f5edfa7e4f10b23bc653e5b6a5a1c79d1f4d232b3c6a29d000000006a473044022067e502e82d02e7a1a3b504897dfec4ea8df71a3b77cfe1b9cbf7d3f16a63642e02206e3b32b1e6b8f184a654bd22c6cb4a616274e0e44ed14e7f3e54d5e2d840cc6f012102a84c91d495bfecb17ea00e1dd6c634755643b95a09856c7cde4575a11b3a48e6ffffffff01a0860100000000001976a91489abcdefabbaabbaabbaabbaabbaabbaabbaabba88ac00000000" + ), + }, + 'input': [ + {b'\x00': b'\x01\x02\x03'}, # Example input map + ], + 'output': [ + {b'\x00': b'\x04\x05\x06'}, # Example output map + ] + } + self.psbt = PSBT(self.sample_maps) + + def test_serialize(self): + """Test if the PSBT object serializes correctly.""" + serialized = self.psbt.serialize() + + # Check if the serialized PSBT starts with the magic bytes + self.assertTrue(serialized.startswith(MAGIC_BYTES)) + + # Check if the global map is serialized correctly + for key, val in self.sample_maps['global'].items(): + encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val + self.assertIn(encoded_key_val, serialized) + + # Check if the input maps are serialized correctly + for inp in self.sample_maps['input']: + for key, val in inp.items(): + encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val + self.assertIn(encoded_key_val, serialized) + + # Check if the output maps are serialized correctly + for out in self.sample_maps['output']: + for key, val in out.items(): + encoded_key_val = encode_varint(len(key)) + key + encode_varint(len(val)) + val + self.assertIn(encoded_key_val, serialized) + + + def test_parse(self): + """Test if the PSBT object parses correctly.""" + serialized = self.psbt.serialize() + parsed_psbt = PSBT.parse(BytesIO(serialized)) + + # Check if the parsed PSBT matches the original maps + self.assertEqual(parsed_psbt.maps['global'], self.sample_maps['global']) + self.assertEqual(parsed_psbt.maps['input'], self.sample_maps['input']) + self.assertEqual(parsed_psbt.maps['output'], self.sample_maps['output']) + + def test_serialize_and_parse(self): + """Test if serialization and parsing are consistent.""" + serialized = self.psbt.serialize() + parsed_psbt = PSBT.parse(BytesIO(serialized)) + + # Serialize the parsed PSBT and compare with the original serialization + reserialized = parsed_psbt.serialize() + self.assertEqual(serialized, reserialized) + + def test_parse_invalid_magic_bytes(self): + """Test parsing with invalid magic bytes.""" + invalid_psbt = b"abcd" + self.psbt.serialize()[4:] # Replace magic bytes + with self.assertRaises(ValueError) as context: + PSBT.parse(BytesIO(invalid_psbt)) + self.assertEqual(str(context.exception), "Invalid PSBT magic bytes") + + def test_parse_missing_separator(self): + """Test parsing with missing separators.""" + serialized = self.psbt.serialize().replace(SEPARATOR, b"") # Remove separators + with self.assertRaises(Exception): # Replace with a specific exception if implemented + PSBT.parse(BytesIO(serialized)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file