diff --git a/zcash_test_vectors/transaction.py b/zcash_test_vectors/transaction.py index 0c8f0f21..b0e641d0 100644 --- a/zcash_test_vectors/transaction.py +++ b/zcash_test_vectors/transaction.py @@ -5,7 +5,6 @@ Scalar as PallasScalar, ) from .orchard.sinsemilla import group_hash as pallas_group_hash -from .orchard_zsa.digests import NU7_TX_VERSION_BYTES from .sapling.generators import find_group_hash, SPENDING_KEY_BASE from .sapling.jubjub import ( Fq, @@ -508,8 +507,8 @@ def is_coinbase(self): def to_bytes(self, version_bytes, version_group_id, consensus_branch_id): ret = b'' ret += self.header_bytes(version_bytes, version_group_id, consensus_branch_id) - ret += self.transparent_bytes(version_bytes) - ret += self.sapling_bytes(version_bytes) + ret += self.transparent_bytes() + ret += self.sapling_bytes() return ret def header_bytes(self, version_bytes, version_group_id, consensus_branch_id): @@ -523,7 +522,7 @@ def header_bytes(self, version_bytes, version_group_id, consensus_branch_id): return ret - def transparent_bytes(self, version_bytes): + def transparent_bytes(self): ret = b'' # Transparent Transaction Fields ret += write_compact_size(len(self.vin)) @@ -532,14 +531,13 @@ def transparent_bytes(self, version_bytes): ret += write_compact_size(len(self.vout)) for x in self.vout: ret += bytes(x) - if version_bytes == NU7_TX_VERSION_BYTES: - for sighash_info in self.vSighashInfo: - ret += write_compact_size(len(sighash_info)) - ret += bytes(sighash_info) - + ret += self.transparent_sighash_info_bytes() return ret - def sapling_bytes(self, version_bytes): + def transparent_sighash_info_bytes(self): + raise NotImplementedError("The transparent_sighash_info_bytes method must be implemented in the child class.") + + def sapling_bytes(self): ret = b'' # Sapling Transaction Fields has_sapling = len(self.vSpendsSapling) + len(self.vOutputsSapling) > 0 @@ -558,20 +556,20 @@ def sapling_bytes(self, version_bytes): for desc in self.vSpendsSapling: # vSpendProofsSapling ret += bytes(desc.proof) for desc in self.vSpendsSapling: # vSpendAuthSigsSapling - if version_bytes == NU7_TX_VERSION_BYTES: - ret += write_compact_size(len(desc.spendAuthSigInfo)) - ret += bytes(desc.spendAuthSigInfo) - ret += bytes(desc.spendAuthSig) + ret += self.sapling_spend_auth_sig_bytes(desc) for desc in self.vOutputsSapling: # vOutputProofsSapling ret += bytes(desc.proof) if has_sapling: - if version_bytes == NU7_TX_VERSION_BYTES: - ret += write_compact_size(len(self.bindingSigSaplingInfo)) - ret += bytes(self.bindingSigSaplingInfo) - ret += bytes(self.bindingSigSapling) + ret += self.sapling_binding_sig_bytes() return ret + def sapling_spend_auth_sig_bytes(self, desc): + raise NotImplementedError("The sapling_spend_auth_sig_bytes method must be implemented in the child class.") + + def sapling_binding_sig_bytes(self): + raise NotImplementedError("The sapling_binding_sig_bytes method must be implemented in the child class.") + class TransactionV5(TransactionBase): def __init__(self, rand, consensus_branch_id): have_orchard = rand.bool() @@ -597,6 +595,16 @@ def __init__(self, rand, consensus_branch_id): def version_bytes(): return NU5_TX_VERSION_BYTES + def transparent_sighash_info_bytes(self): + # There are no such bytes for V5 transactions. + return b'' + + def sapling_spend_auth_sig_bytes(self, desc): + return bytes(desc.spendAuthSig) + + def sapling_binding_sig_bytes(self): + return bytes(self.bindingSigSapling) + def __bytes__(self): ret = b'' diff --git a/zcash_test_vectors/transaction_v6.py b/zcash_test_vectors/transaction_v6.py index 1598dce6..2624e9e0 100644 --- a/zcash_test_vectors/transaction_v6.py +++ b/zcash_test_vectors/transaction_v6.py @@ -198,6 +198,27 @@ def header_bytes(self, version_bytes, version_group_id, consensus_branch_id): ret += struct.pack('