diff --git a/hathor/transaction/exceptions.py b/hathor/transaction/exceptions.py index 2d1bfbda8..e44419660 100644 --- a/hathor/transaction/exceptions.py +++ b/hathor/transaction/exceptions.py @@ -74,6 +74,10 @@ class ConflictingInputs(TxValidationError): """Inputs in the tx are spending the same output""" +class OutputNotSelected(TxValidationError): + """At least one output is not selected for signing by some input.""" + + class TooManyOutputs(TxValidationError): """More than 256 outputs""" @@ -202,3 +206,25 @@ class VerifyFailed(ScriptError): class TimeLocked(ScriptError): """Transaction is invalid because it is time locked""" + + +class InputNotSelectedError(ScriptError): + """Raised when an input does not select itself for signing in its script.""" + + +class MaxInputsExceededError(ScriptError): + """The transaction has more inputs than the maximum configured in the script.""" + + +class MaxOutputsExceededError(ScriptError): + """The transaction has more outputs than the maximum configured in the script.""" + + +class InputsOutputsLimitModelInvalid(ScriptError): + """ + Raised when the inputs outputs limit model could not be constructed from the arguments provided in the script. + """ + + +class CustomSighashModelInvalid(ScriptError): + """Raised when the sighash model could not be constructed from the arguments provided in the script.""" diff --git a/hathor/transaction/scripts/execute.py b/hathor/transaction/scripts/execute.py index 23109afbc..8fca1fc1a 100644 --- a/hathor/transaction/scripts/execute.py +++ b/hathor/transaction/scripts/execute.py @@ -13,17 +13,33 @@ # limitations under the License. import struct -from typing import NamedTuple, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from hathor.conf.get_settings import get_global_settings from hathor.transaction import BaseTransaction, Transaction, TxInput from hathor.transaction.exceptions import DataIndexError, FinalStackInvalid, InvalidScriptError, OutOfData +if TYPE_CHECKING: + from hathor.transaction.scripts.script_context import ScriptContext + -class ScriptExtras(NamedTuple): +@dataclass(slots=True, frozen=True, kw_only=True) +class ScriptExtras: + """ + A simple container for auxiliary data that may be used during execution of scripts. + """ tx: Transaction - txin: TxInput + input_index: int spent_tx: BaseTransaction + @property + def txin(self) -> TxInput: + return self.tx.inputs[self.input_index] + + def __post_init__(self) -> None: + assert self.txin.tx_id == self.spent_tx.hash + # XXX: Because the Stack is a heterogeneous list of bytes and int, and some OPs only work for when the stack has some # or the other type, there are many places that require an assert to prevent the wrong type from being used, @@ -39,7 +55,7 @@ class OpcodePosition(NamedTuple): position: int -def execute_eval(data: bytes, log: list[str], extras: ScriptExtras) -> None: +def execute_eval(data: bytes, log: list[str], extras: ScriptExtras) -> 'ScriptContext': """ Execute eval from data executing opcode methods :param data: data to be evaluated that contains data and opcodes @@ -56,8 +72,9 @@ def execute_eval(data: bytes, log: list[str], extras: ScriptExtras) -> None: """ from hathor.transaction.scripts.opcode import Opcode, execute_op_code from hathor.transaction.scripts.script_context import ScriptContext + settings = get_global_settings() stack: Stack = [] - context = ScriptContext(stack=stack, logs=log, extras=extras) + context = ScriptContext(settings=settings, stack=stack, logs=log, extras=extras) data_len = len(data) pos = 0 while pos < data_len: @@ -70,6 +87,8 @@ def execute_eval(data: bytes, log: list[str], extras: ScriptExtras) -> None: evaluate_final_stack(stack, log) + return context + def evaluate_final_stack(stack: Stack, log: list[str]) -> None: """ Checks the final state of the stack. @@ -88,25 +107,20 @@ def evaluate_final_stack(stack: Stack, log: list[str]) -> None: raise FinalStackInvalid('\n'.join(log)) -def script_eval(tx: Transaction, txin: TxInput, spent_tx: BaseTransaction) -> None: - """Evaluates the output script and input data according to - a very limited subset of Bitcoin's scripting language. - - :param tx: the transaction being validated, the 'owner' of the input data - :type tx: :py:class:`hathor.transaction.Transaction` - - :param txin: transaction input being evaluated - :type txin: :py:class:`hathor.transaction.TxInput` - - :param spent_tx: the transaction referenced by the input - :type spent_tx: :py:class:`hathor.transaction.BaseTransaction` +def script_eval(*, tx: Transaction, spent_tx: BaseTransaction, input_index: int) -> 'ScriptContext': + """ + Evaluates the output script and input data according to a very limited subset of Bitcoin's scripting language. + Raises ScriptError if script verification fails - :raises ScriptError: if script verification fails + Args: + tx: the transaction being validated, the 'owner' of the input data + spent_tx: the transaction referenced by the input + input_index: index of the transaction input being evaluated """ - input_data = txin.data - output_script = spent_tx.outputs[txin.index].script + extras = ScriptExtras(tx=tx, input_index=input_index, spent_tx=spent_tx) + input_data = extras.txin.data + output_script = spent_tx.outputs[extras.txin.index].script log: list[str] = [] - extras = ScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx) from hathor.transaction.scripts import MultiSig if MultiSig.re_match.search(output_script): @@ -115,17 +129,17 @@ def script_eval(tx: Transaction, txin: TxInput, spent_tx: BaseTransaction) -> No # we can't use input_data + output_script because it will end with an invalid stack # i.e. the signatures will still be on the stack after ouput_script is executed redeem_script_pos = MultiSig.get_multisig_redeem_script_pos(input_data) - full_data = txin.data[redeem_script_pos:] + output_script + full_data = extras.txin.data[redeem_script_pos:] + output_script execute_eval(full_data, log, extras) # Second, we need to validate that the signatures on the input_data solves the redeem_script # we pop and append the redeem_script to the input_data and execute it multisig_data = MultiSig.get_multisig_data(extras.txin.data) - execute_eval(multisig_data, log, extras) + return execute_eval(multisig_data, log, extras) else: # merge input_data and output_script full_data = input_data + output_script - execute_eval(full_data, log, extras) + return execute_eval(full_data, log, extras) def decode_opn(opcode: int) -> int: diff --git a/hathor/transaction/scripts/opcode.py b/hathor/transaction/scripts/opcode.py index 460c66821..22230c25f 100644 --- a/hathor/transaction/scripts/opcode.py +++ b/hathor/transaction/scripts/opcode.py @@ -16,6 +16,7 @@ import struct from enum import IntEnum +import pydantic from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec @@ -28,9 +29,14 @@ is_pubkey_compressed, ) from hathor.transaction.exceptions import ( + CustomSighashModelInvalid, EqualVerifyFailed, + InputNotSelectedError, + InputsOutputsLimitModelInvalid, InvalidScriptError, InvalidStackData, + MaxInputsExceededError, + MaxOutputsExceededError, MissingStackItems, OracleChecksigFailed, ScriptError, @@ -39,6 +45,8 @@ ) from hathor.transaction.scripts.execute import Stack, binary_to_int, decode_opn, get_data_value, get_script_op from hathor.transaction.scripts.script_context import ScriptContext +from hathor.transaction.scripts.sighash import InputsOutputsLimit, SighashBitmask +from hathor.transaction.util import bytes_to_int class Opcode(IntEnum): @@ -72,6 +80,9 @@ class Opcode(IntEnum): OP_DATA_GREATERTHAN = 0xC1 OP_FIND_P2PKH = 0xD0 OP_DATA_MATCH_VALUE = 0xD1 + OP_SIGHASH_BITMASK = 0xE0 + OP_SIGHASH_RANGE = 0xE1 + OP_MAX_INPUTS_OUTPUTS = 0xE2 @classmethod def is_pushdata(cls, opcode: int) -> bool: @@ -249,7 +260,8 @@ def op_checksig(context: ScriptContext) -> None: # pubkey is not compressed public key raise ScriptError('OP_CHECKSIG: pubkey is not a public key') from e try: - public_key.verify(signature, context.extras.tx.get_sighash_all_data(), ec.ECDSA(hashes.SHA256())) + sighash_data = context.get_tx_sighash_data(context.extras.tx) + public_key.verify(signature, sighash_data, ec.ECDSA(hashes.SHA256())) # valid, push true to stack context.stack.append(1) except InvalidSignature: @@ -583,7 +595,7 @@ def op_checkmultisig(context: ScriptContext) -> None: while pubkey_index < len(pubkeys): pubkey = pubkeys[pubkey_index] new_stack = [signature, pubkey] - op_checksig(ScriptContext(stack=new_stack, logs=context.logs, extras=context.extras)) + op_checksig(ScriptContext(stack=new_stack, logs=context.logs, extras=context.extras, settings=settings)) result = new_stack.pop() pubkey_index += 1 if result == 1: @@ -617,6 +629,59 @@ def op_integer(opcode: int, stack: Stack) -> None: raise ScriptError(e) from e +def op_sighash_bitmask(context: ScriptContext) -> None: + """Pop two items from the stack, constructing a sighash bitmask and setting it in the script context.""" + if len(context.stack) < 2: + raise MissingStackItems(f'OP_SIGHASH_BITMASK: expected 2 elements on stack, has {len(context.stack)}') + + outputs = context.stack.pop() + inputs = context.stack.pop() + assert isinstance(inputs, bytes) + assert isinstance(outputs, bytes) + + try: + sighash = SighashBitmask( + inputs=bytes_to_int(inputs), + outputs=bytes_to_int(outputs) + ) + except pydantic.ValidationError as e: + raise CustomSighashModelInvalid('Could not construct sighash bitmask.') from e + + if context.extras.input_index not in sighash.get_input_indexes(): + raise InputNotSelectedError( + f'Input at index {context.extras.input_index} must select itself when using a custom sighash.' + ) + + context.set_sighash(sighash) + + +def op_max_inputs_outputs(context: ScriptContext) -> None: + """Pop two items from the stack, constructing an inputs and outputs limit and setting it in the script context.""" + if len(context.stack) < 2: + raise MissingStackItems(f'OP_MAX_INPUTS_OUTPUTS: expected 2 elements on stack, has {len(context.stack)}') + + max_outputs = context.stack.pop() + max_inputs = context.stack.pop() + assert isinstance(max_inputs, bytes) + assert isinstance(max_outputs, bytes) + + try: + limit = InputsOutputsLimit( + max_inputs=bytes_to_int(max_inputs), + max_outputs=bytes_to_int(max_outputs) + ) + except pydantic.ValidationError as e: + raise InputsOutputsLimitModelInvalid("Could not construct inputs and outputs limits.") from e + + tx_inputs_len = len(context.extras.tx.inputs) + if tx_inputs_len > limit.max_inputs: + raise MaxInputsExceededError(f'Maximum number of inputs exceeded ({tx_inputs_len} > {limit.max_inputs}).') + + tx_outputs_len = len(context.extras.tx.outputs) + if tx_outputs_len > limit.max_outputs: + raise MaxOutputsExceededError(f'Maximum number of outputs exceeded ({tx_outputs_len} > {limit.max_outputs}).') + + def execute_op_code(opcode: Opcode, context: ScriptContext) -> None: """ Execute a function opcode. @@ -625,6 +690,8 @@ def execute_op_code(opcode: Opcode, context: ScriptContext) -> None: opcode: the opcode to be executed. context: the script context to be manipulated. """ + if not is_opcode_valid(opcode): + raise ScriptError(f'Opcode "{opcode.name}" is invalid.') context.logs.append(f'Executing function opcode {opcode.name} ({hex(opcode.value)})') match opcode: case Opcode.OP_DUP: op_dup(context) @@ -639,4 +706,26 @@ def execute_op_code(opcode: Opcode, context: ScriptContext) -> None: case Opcode.OP_DATA_MATCH_VALUE: op_data_match_value(context) case Opcode.OP_CHECKDATASIG: op_checkdatasig(context) case Opcode.OP_FIND_P2PKH: op_find_p2pkh(context) + case Opcode.OP_SIGHASH_BITMASK: op_sighash_bitmask(context) + case Opcode.OP_MAX_INPUTS_OUTPUTS: op_max_inputs_outputs(context) case _: raise ScriptError(f'unknown opcode: {opcode}') + + +def is_opcode_valid(opcode: Opcode) -> bool: + """Return whether an opcode is valid, that is, it's currently enabled.""" + valid_opcodes = [ + Opcode.OP_DUP, + Opcode.OP_EQUAL, + Opcode.OP_EQUALVERIFY, + Opcode.OP_CHECKSIG, + Opcode.OP_HASH160, + Opcode.OP_GREATERTHAN_TIMESTAMP, + Opcode.OP_CHECKMULTISIG, + Opcode.OP_DATA_STREQUAL, + Opcode.OP_DATA_GREATERTHAN, + Opcode.OP_DATA_MATCH_VALUE, + Opcode.OP_CHECKDATASIG, + Opcode.OP_FIND_P2PKH, + ] + + return opcode in valid_opcodes diff --git a/hathor/transaction/scripts/p2pkh.py b/hathor/transaction/scripts/p2pkh.py index 52812680c..76cc412c5 100644 --- a/hathor/transaction/scripts/p2pkh.py +++ b/hathor/transaction/scripts/p2pkh.py @@ -20,6 +20,7 @@ from hathor.transaction.scripts.construct import get_pushdata, re_compile from hathor.transaction.scripts.hathor_script import HathorScript from hathor.transaction.scripts.opcode import Opcode +from hathor.transaction.scripts.sighash import InputsOutputsLimit, SighashBitmask class P2PKH(BaseScript): @@ -91,7 +92,14 @@ def create_output_script(cls, address: bytes, timelock: Optional[Any] = None) -> return s.data @classmethod - def create_input_data(cls, public_key_bytes: bytes, signature: bytes) -> bytes: + def create_input_data( + cls, + public_key_bytes: bytes, + signature: bytes, + *, + sighash: SighashBitmask | None = None, + inputs_outputs_limit: InputsOutputsLimit | None = None + ) -> bytes: """ :param private_key: key corresponding to the address we want to spend tokens from :type private_key: :py:class:`cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey` @@ -99,8 +107,20 @@ def create_input_data(cls, public_key_bytes: bytes, signature: bytes) -> bytes: :rtype: bytes """ s = HathorScript() + + if sighash: + s.pushData(sighash.inputs) + s.pushData(sighash.outputs) + s.addOpcode(Opcode.OP_SIGHASH_BITMASK) + + if inputs_outputs_limit: + s.pushData(inputs_outputs_limit.max_inputs) + s.pushData(inputs_outputs_limit.max_outputs) + s.addOpcode(Opcode.OP_MAX_INPUTS_OUTPUTS) + s.pushData(signature) s.pushData(public_key_bytes) + return s.data @classmethod diff --git a/hathor/transaction/scripts/script_context.py b/hathor/transaction/scripts/script_context.py index 925a881f1..a414671b8 100644 --- a/hathor/transaction/scripts/script_context.py +++ b/hathor/transaction/scripts/script_context.py @@ -12,14 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib + +from typing_extensions import assert_never + +from hathor.conf.settings import HathorSettings +from hathor.transaction import Transaction +from hathor.transaction.exceptions import ScriptError from hathor.transaction.scripts.execute import ScriptExtras, Stack +from hathor.transaction.scripts.sighash import SighashAll, SighashBitmask, SighashType class ScriptContext: """A context to be manipulated during script execution. A separate instance must be used for each script.""" - __slots__ = ('stack', 'logs', 'extras') + __slots__ = ('stack', 'logs', 'extras', '_settings', '_sighash') - def __init__(self, *, stack: Stack, logs: list[str], extras: ScriptExtras) -> None: + def __init__(self, *, stack: Stack, logs: list[str], extras: ScriptExtras, settings: HathorSettings) -> None: self.stack = stack self.logs = logs self.extras = extras + self._settings = settings + self._sighash: SighashType = SighashAll() + + def set_sighash(self, sighash: SighashType) -> None: + """ + Set a Sighash type in this context. + It can only be set once, that is, a script cannot use more than one sighash type. + """ + if type(self._sighash) is not SighashAll: + raise ScriptError('Cannot modify sighash after it is already set.') + + self._sighash = sighash + + def get_tx_sighash_data(self, tx: Transaction) -> bytes: + """ + Return the sighash data for a tx, depending on the sighash type set in this context. + Must be used when verifying signatures during script execution. + """ + match self._sighash: + case SighashAll(): + return tx.get_sighash_all_data() + case SighashBitmask(): + data = tx.get_custom_sighash_data(self._sighash) + return hashlib.sha256(data).digest() + case _: + assert_never(self._sighash) + + def get_selected_outputs(self) -> set[int]: + """Get a set with all output indexes selected (that is, signed) in this context.""" + match self._sighash: + case SighashAll(): + return set(range(self._settings.MAX_NUM_OUTPUTS)) + case SighashBitmask(): + return set(self._sighash.get_output_indexes()) + case _: + assert_never(self._sighash) diff --git a/hathor/transaction/scripts/sighash.py b/hathor/transaction/scripts/sighash.py new file mode 100644 index 000000000..01565b72e --- /dev/null +++ b/hathor/transaction/scripts/sighash.py @@ -0,0 +1,69 @@ +# Copyright 2023 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TypeAlias + +from pydantic import Field +from typing_extensions import override + +from hathor.utils.pydantic import BaseModel + + +@dataclass(frozen=True, slots=True) +class SighashAll: + """A model representing the sighash all, which is the default sighash type.""" + pass + + +class CustomSighash(ABC, BaseModel): + """An interface to be implemented by custom sighash models.""" + @abstractmethod + def get_input_indexes(self) -> list[int]: + """Return a list of input indexes selected by this sighash.""" + raise NotImplementedError + + @abstractmethod + def get_output_indexes(self) -> list[int]: + """Return a list of output indexes selected by this sighash.""" + raise NotImplementedError + + +class SighashBitmask(CustomSighash): + """A model representing the sighash bitmask type config.""" + inputs: int = Field(ge=0x01, le=0xFF) + outputs: int = Field(ge=0x00, le=0xFF) + + @override + def get_input_indexes(self) -> list[int]: + return self._get_indexes(self.inputs) + + @override + def get_output_indexes(self) -> list[int]: + return self._get_indexes(self.outputs) + + @staticmethod + def _get_indexes(bitmask: int) -> list[int]: + """Return a list of indexes equivalent to some bitmask.""" + return [index for index in range(8) if (bitmask >> index) & 1] + + +SighashType: TypeAlias = SighashAll | SighashBitmask + + +class InputsOutputsLimit(BaseModel): + """A model representing inputs and outputs limits config.""" + max_inputs: int = Field(ge=1) + max_outputs: int = Field(ge=1) diff --git a/hathor/transaction/token_creation_tx.py b/hathor/transaction/token_creation_tx.py index 629050197..fe078d470 100644 --- a/hathor/transaction/token_creation_tx.py +++ b/hathor/transaction/token_creation_tx.py @@ -147,8 +147,8 @@ def get_sighash_all(self) -> bytes: :return: Serialization of the inputs, outputs and tokens :rtype: bytes """ - if self._sighash_cache: - return self._sighash_cache + if self._sighash_all_cache: + return self._sighash_all_cache struct_bytes = pack( _SIGHASH_ALL_FORMAT_STRING, @@ -169,7 +169,7 @@ def get_sighash_all(self) -> bytes: struct_bytes += b''.join(tx_outputs) struct_bytes += self.serialize_token_info() - self._sighash_cache = struct_bytes + self._sighash_all_cache = struct_bytes return struct_bytes diff --git a/hathor/transaction/transaction.py b/hathor/transaction/transaction.py index a51eaeffe..6865226c3 100644 --- a/hathor/transaction/transaction.py +++ b/hathor/transaction/transaction.py @@ -24,13 +24,14 @@ from hathor.exception import InvalidNewTransaction from hathor.transaction import TxInput, TxOutput, TxVersion from hathor.transaction.base_transaction import TX_HASH_SIZE, GenericVertex -from hathor.transaction.exceptions import InvalidToken +from hathor.transaction.exceptions import InvalidScriptError, InvalidToken from hathor.transaction.static_metadata import TransactionStaticMetadata from hathor.transaction.util import VerboseCallback, unpack, unpack_len from hathor.types import TokenUid, VertexId if TYPE_CHECKING: from hathor.conf.settings import HathorSettings + from hathor.transaction.scripts.sighash import CustomSighash from hathor.transaction.storage import TransactionStorage # noqa: F401 # Signal bits (B), version (B), token uids len (B) and inputs len (B), outputs len (B). @@ -88,7 +89,7 @@ def __init__( settings=settings ) self.tokens = tokens or [] - self._sighash_cache: Optional[bytes] = None + self._sighash_all_cache: Optional[bytes] = None self._sighash_data_cache: Optional[bytes] = None @property @@ -193,32 +194,37 @@ def get_sighash_all(self) -> bytes: # This method does not depend on the input itself, however we call it for each one to sign it. # For transactions that have many inputs there is a significant decrease on the verify time # when using this cache, so we call this method only once. - if self._sighash_cache: - return self._sighash_cache + if self._sighash_all_cache: + return self._sighash_all_cache + sighash = self._get_sighash(inputs=self.inputs, outputs=self.outputs) + self._sighash_all_cache = sighash + + return sighash + + def _get_sighash(self, *, inputs: list[TxInput], outputs: list[TxOutput]) -> bytes: + """Return the sighash data for this tx using a custom list of inputs and outputs.""" struct_bytes = bytearray( pack( _SIGHASH_ALL_FORMAT_STRING, self.signal_bits, self.version, len(self.tokens), - len(self.inputs), - len(self.outputs) + len(inputs), + len(outputs) ) ) for token_uid in self.tokens: struct_bytes += token_uid - for tx_input in self.inputs: + for tx_input in inputs: struct_bytes += tx_input.get_sighash_bytes() - for tx_output in self.outputs: + for tx_output in outputs: struct_bytes += bytes(tx_output) - ret = bytes(struct_bytes) - self._sighash_cache = ret - return ret + return bytes(struct_bytes) def get_sighash_all_data(self) -> bytes: """Return the sha256 hash of sighash_all""" @@ -227,6 +233,19 @@ def get_sighash_all_data(self) -> bytes: return self._sighash_data_cache + def get_custom_sighash_data(self, sighash: 'CustomSighash') -> bytes: + """ + Return the sighash data for this tx using a custom sighash type. + Inputs and outputs are selected according to indexes selected by the sighash. + """ + try: + inputs = [self.inputs[index] for index in sighash.get_input_indexes()] + outputs = [self.outputs[index] for index in sighash.get_output_indexes()] + except IndexError: + raise InvalidScriptError('Custom sighash selected nonexistent input/output.') + + return self._get_sighash(inputs=inputs, outputs=outputs) + def get_token_uid(self, index: int) -> TokenUid: """Returns the token uid with corresponding index from the tx token uid list. diff --git a/hathor/verification/transaction_verifier.py b/hathor/verification/transaction_verifier.py index 153fedd10..f8b0c6ed9 100644 --- a/hathor/verification/transaction_verifier.py +++ b/hathor/verification/transaction_verifier.py @@ -20,7 +20,7 @@ from hathor.profiler import get_cpu_profiler from hathor.reward_lock import get_spent_reward_locked_info from hathor.reward_lock.reward_lock import get_minimum_best_height -from hathor.transaction import BaseTransaction, Transaction, TxInput +from hathor.transaction import BaseTransaction, Transaction from hathor.transaction.exceptions import ( ConflictingInputs, DuplicatedParents, @@ -31,6 +31,7 @@ InvalidInputDataSize, InvalidToken, NoInputError, + OutputNotSelected, RewardLocked, ScriptError, TimestampError, @@ -38,6 +39,7 @@ TooManySigOps, WeightError, ) +from hathor.transaction.scripts.script_context import ScriptContext from hathor.transaction.transaction import TokenInfo from hathor.transaction.util import get_deposit_amount, get_withdraw_amount from hathor.types import TokenUid, VertexId @@ -104,7 +106,9 @@ def verify_inputs(self, tx: Transaction, *, skip_script: bool = False) -> None: from hathor.transaction.storage.exceptions import TransactionDoesNotExist spent_outputs: set[tuple[VertexId, int]] = set() - for input_tx in tx.inputs: + all_selected_outputs: set[int] = set() + + for input_index, input_tx in enumerate(tx.inputs): if len(input_tx.data) > self._settings.MAX_INPUT_DATA_SIZE: raise InvalidInputDataSize('size: {} and max-size: {}'.format( len(input_tx.data), self._settings.MAX_INPUT_DATA_SIZE @@ -127,7 +131,9 @@ def verify_inputs(self, tx: Transaction, *, skip_script: bool = False) -> None: )) if not skip_script: - self.verify_script(tx=tx, input_tx=input_tx, spent_tx=spent_tx) + script_context = self.verify_script(tx=tx, spent_tx=spent_tx, input_index=input_index) + selected_outputs = script_context.get_selected_outputs() + all_selected_outputs.update(selected_outputs) # check if any other input in this tx is spending the same output key = (input_tx.tx_id, input_tx.index) @@ -136,15 +142,26 @@ def verify_inputs(self, tx: Transaction, *, skip_script: bool = False) -> None: tx.hash_hex, input_tx.tx_id.hex(), input_tx.index)) spent_outputs.add(key) - def verify_script(self, *, tx: Transaction, input_tx: TxInput, spent_tx: BaseTransaction) -> None: + if not skip_script: + for index, _ in enumerate(tx.outputs): + if index not in all_selected_outputs: + raise OutputNotSelected(f'Output at index {index} is not signed by any input.') + + def verify_script( + self, + *, + tx: Transaction, + spent_tx: BaseTransaction, + input_index: int + ) -> ScriptContext: """ :type tx: Transaction - :type input_tx: TxInput :type spent_tx: Transaction + :type input_index: int """ from hathor.transaction.scripts import script_eval try: - script_eval(tx, input_tx, spent_tx) + return script_eval(tx=tx, spent_tx=spent_tx, input_index=input_index) except ScriptError as e: raise InvalidInputData(e) from e diff --git a/tests/tx/test_multisig.py b/tests/tx/test_multisig.py index 25222b90d..967a427ff 100644 --- a/tests/tx/test_multisig.py +++ b/tests/tx/test_multisig.py @@ -132,7 +132,7 @@ def test_spend_multisig(self): expected_dict = {'type': 'MultiSig', 'address': self.multisig_address_b58, 'timelock': None} self.assertEqual(cls_script.to_human_readable(), expected_dict) - script_eval(tx, tx_input, tx1) + script_eval(tx=tx, spent_tx=tx1, input_index=0) # Script error with self.assertRaises(ScriptError): diff --git a/tests/tx/test_nano_contracts.py b/tests/tx/test_nano_contracts.py index b23addf81..60bae9826 100644 --- a/tests/tx/test_nano_contracts.py +++ b/tests/tx/test_nano_contracts.py @@ -36,6 +36,6 @@ def test_match_values(self): input_data = NanoContractMatchValues.create_input_data( base64.b64decode(oracle_data), base64.b64decode(oracle_signature), base64.b64decode(pubkey)) txin = TxInput(b'aa', 0, input_data) - spent_tx = Transaction(outputs=[TxOutput(20, script)]) - tx = Transaction(outputs=[TxOutput(20, P2PKH.create_output_script(address))]) - script_eval(tx, txin, spent_tx) + spent_tx = Transaction(hash=b'aa', outputs=[TxOutput(20, script)]) + tx = Transaction(inputs=[txin], outputs=[TxOutput(20, P2PKH.create_output_script(address))]) + script_eval(tx=tx, spent_tx=spent_tx, input_index=0) diff --git a/tests/tx/test_scripts.py b/tests/tx/test_scripts.py index 34ce6ac25..a3fd98a92 100644 --- a/tests/tx/test_scripts.py +++ b/tests/tx/test_scripts.py @@ -176,22 +176,22 @@ def test_pushdata1(self): def test_dup(self): with self.assertRaises(MissingStackItems): - op_dup(ScriptContext(stack=[], logs=[], extras=Mock())) + op_dup(ScriptContext(stack=[], logs=[], extras=Mock(), settings=Mock())) stack = [1] - op_dup(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_dup(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack[-1], stack[-2]) def test_equalverify(self): elem = b'a' with self.assertRaises(MissingStackItems): - op_equalverify(ScriptContext(stack=[elem], logs=[], extras=Mock())) + op_equalverify(ScriptContext(stack=[elem], logs=[], extras=Mock(), settings=Mock())) # no exception should be raised - op_equalverify(ScriptContext(stack=[elem, elem], logs=[], extras=Mock())) + op_equalverify(ScriptContext(stack=[elem, elem], logs=[], extras=Mock(), settings=Mock())) with self.assertRaises(EqualVerifyFailed): - op_equalverify(ScriptContext(stack=[elem, b'aaaa'], logs=[], extras=Mock())) + op_equalverify(ScriptContext(stack=[elem, b'aaaa'], logs=[], extras=Mock(), settings=Mock())) def test_checksig_raise_on_uncompressed_pubkey(self): """ Uncompressed pubkeys shoud not be accepted, even if they solve the signature @@ -213,11 +213,11 @@ def test_checksig_raise_on_uncompressed_pubkey(self): # ScriptError if pubkey is not a valid compressed public key # with wrong signature with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[b'123', pubkey_uncompressed], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[b'123', pubkey_uncompressed], logs=[], extras=Mock(), settings=Mock())) # or with rigth one # this will make sure the signature is not made when parameters are wrong with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[signature, pubkey_uncompressed], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[signature, pubkey_uncompressed], logs=[], extras=Mock(), settings=Mock())) def test_checksig_check_for_compressed_pubkey(self): """ Compressed pubkeys bytes representation always start with a byte 2 or 3 @@ -226,19 +226,19 @@ def test_checksig_check_for_compressed_pubkey(self): """ # ScriptError if pubkey is not a public key but starts with 2 or 3 with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[b'\x0233', b'\x0233'], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[b'\x0233', b'\x0233'], logs=[], extras=Mock(), settings=Mock())) with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[b'\x0321', b'\x0321'], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[b'\x0321', b'\x0321'], logs=[], extras=Mock(), settings=Mock())) # ScriptError if pubkey does not start with 2 or 3 with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[b'\x0123', b'\x0123'], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[b'\x0123', b'\x0123'], logs=[], extras=Mock(), settings=Mock())) with self.assertRaises(ScriptError): - op_checksig(ScriptContext(stack=[b'\x0423', b'\x0423'], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[b'\x0423', b'\x0423'], logs=[], extras=Mock(), settings=Mock())) def test_checksig(self): with self.assertRaises(MissingStackItems): - op_checksig(ScriptContext(stack=[1], logs=[], extras=Mock())) + op_checksig(ScriptContext(stack=[1], logs=[], extras=Mock(), settings=Mock())) block = self.genesis_blocks[0] @@ -253,15 +253,15 @@ def test_checksig(self): signature = self.genesis_private_key.sign(hashed_data, ec.ECDSA(hashes.SHA256())) pubkey_bytes = get_public_key_bytes_compressed(self.genesis_public_key) - extras = ScriptExtras(tx=tx, txin=Mock(), spent_tx=Mock()) + extras = ScriptExtras(tx=tx, spent_tx=block, input_index=0) # wrong signature puts False (0) on stack stack = [b'aaaaaaaaa', pubkey_bytes] - op_checksig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checksig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(0, stack.pop()) stack = [signature, pubkey_bytes] - op_checksig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checksig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(1, stack.pop()) def test_checksig_cache(self): @@ -278,22 +278,22 @@ def test_checksig_cache(self): signature = self.genesis_private_key.sign(hashed_data, ec.ECDSA(hashes.SHA256())) pubkey_bytes = get_public_key_bytes_compressed(self.genesis_public_key) - extras = ScriptExtras(tx=tx, txin=Mock(), spent_tx=Mock()) + extras = ScriptExtras(tx=tx, spent_tx=block, input_index=0) stack = [signature, pubkey_bytes] self.assertIsNone(tx._sighash_data_cache) - op_checksig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checksig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertIsNotNone(tx._sighash_data_cache) self.assertEqual(1, stack.pop()) def test_hash160(self): with self.assertRaises(MissingStackItems): - op_hash160(ScriptContext(stack=[], logs=[], extras=Mock())) + op_hash160(ScriptContext(stack=[], logs=[], extras=Mock(), settings=Mock())) elem = b'aaaaaaaa' hash160 = get_hash160(elem) stack = [elem] - op_hash160(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_hash160(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(hash160, stack.pop()) def test_checkdatasig_raise_on_uncompressed_pubkey(self): @@ -316,27 +316,33 @@ def test_checkdatasig_raise_on_uncompressed_pubkey(self): # with wrong signature stack = [data, b'123', pubkey_uncompressed] with self.assertRaises(ScriptError): - op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) # or with rigth one # this will make sure the signature is not made when parameters are wrong stack = [data, signature, pubkey_uncompressed] with self.assertRaises(ScriptError): - op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) def test_checkdatasig_check_for_compressed_pubkey(self): # ScriptError if pubkey is not a public key but starts with 2 or 3 with self.assertRaises(ScriptError): - op_checkdatasig(ScriptContext(stack=[b'\x0233', b'\x0233', b'\x0233'], logs=[], extras=Mock())) + op_checkdatasig( + ScriptContext(stack=[b'\x0233', b'\x0233', b'\x0233'], logs=[], extras=Mock(), settings=Mock()) + ) with self.assertRaises(ScriptError): - op_checkdatasig(ScriptContext(stack=[b'\x0321', b'\x0321', b'\x0321'], logs=[], extras=Mock())) + op_checkdatasig( + ScriptContext(stack=[b'\x0321', b'\x0321', b'\x0321'], logs=[], extras=Mock(), settings=Mock()) + ) # ScriptError if pubkey is not a public key with self.assertRaises(ScriptError): - op_checkdatasig(ScriptContext(stack=[b'\x0123', b'\x0123', b'\x0123'], logs=[], extras=Mock())) + op_checkdatasig( + ScriptContext(stack=[b'\x0123', b'\x0123', b'\x0123'], logs=[], extras=Mock(), settings=Mock()) + ) def test_checkdatasig(self): with self.assertRaises(MissingStackItems): - op_checkdatasig(ScriptContext(stack=[1, 1], logs=[], extras=Mock())) + op_checkdatasig(ScriptContext(stack=[1, 1], logs=[], extras=Mock(), settings=Mock())) data = b'some_random_data' signature = self.genesis_private_key.sign(data, ec.ECDSA(hashes.SHA256())) @@ -344,12 +350,12 @@ def test_checkdatasig(self): stack = [data, signature, pubkey_bytes] # no exception should be raised and data is left on stack - op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(data, stack.pop()) stack = [b'data_not_matching', signature, pubkey_bytes] with self.assertRaises(OracleChecksigFailed): - op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_checkdatasig(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) def test_get_data_value(self): value0 = b'value0' @@ -370,7 +376,7 @@ def test_get_data_value(self): def test_data_strequal(self): with self.assertRaises(MissingStackItems): - op_data_strequal(ScriptContext(stack=[1, 1], logs=[], extras=Mock())) + op_data_strequal(ScriptContext(stack=[1, 1], logs=[], extras=Mock(), settings=Mock())) value0 = b'value0' value1 = b'vvvalue1' @@ -379,20 +385,20 @@ def test_data_strequal(self): data = (bytes([len(value0)]) + value0 + bytes([len(value1)]) + value1 + bytes([len(value2)]) + value2) stack = [data, 0, value0] - op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), data) stack = [data, 1, value0] with self.assertRaises(VerifyFailed): - op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) stack = [data, b'\x00', value0] with self.assertRaises(VerifyFailed): - op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_strequal(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) def test_data_greaterthan(self): with self.assertRaises(MissingStackItems): - op_data_greaterthan(ScriptContext(stack=[1, 1], logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=[1, 1], logs=[], extras=Mock(), settings=Mock())) value0 = struct.pack('!I', 1000) value1 = struct.pack('!I', 1) @@ -400,24 +406,24 @@ def test_data_greaterthan(self): data = (bytes([len(value0)]) + value0 + bytes([len(value1)]) + value1) stack = [data, 0, struct.pack('!I', 999)] - op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), data) stack = [data, 1, struct.pack('!I', 0)] - op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), data) with self.assertRaises(VerifyFailed): stack = [data, 1, struct.pack('!I', 1)] - op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) stack = [data, 1, b'not_an_int'] with self.assertRaises(VerifyFailed): - op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) stack = [data, b'\x00', struct.pack('!I', 0)] with self.assertRaises(VerifyFailed): - op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_greaterthan(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) def test_data_match_interval(self): with self.assertRaises(MissingStackItems): @@ -453,40 +459,40 @@ def test_data_match_interval(self): def test_data_match_value(self): with self.assertRaises(MissingStackItems): - op_data_match_value(ScriptContext(stack=[1, b'2'], logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=[1, b'2'], logs=[], extras=Mock(), settings=Mock())) value0 = struct.pack('!I', 1000) data = (bytes([len(value0)]) + value0) stack = [data, 0, 'key1', struct.pack('!I', 1000), 'key2', struct.pack('!I', 1005), 'key3', bytes([2])] - op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), 'key2') self.assertEqual(len(stack), 0) stack = [data, 0, 'key1', struct.pack('!I', 999), 'key2', struct.pack('!I', 1000), 'key3', bytes([2])] - op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), 'key3') self.assertEqual(len(stack), 0) # missing 1 item on stack stack = [data, 0, 'key1', struct.pack('!I', 1000), 'key2', struct.pack('!I', 1000), bytes([2])] with self.assertRaises(MissingStackItems): - op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) # no value matches stack = [data, 0, 'key1', struct.pack('!I', 999), 'key2', struct.pack('!I', 1111), 'key3', bytes([2])] - op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), 'key1') self.assertEqual(len(stack), 0) # value should be an integer stack = [data, 0, 'key1', struct.pack('!I', 100), 'key2', b'not_an_int', 'key3', bytes([2])] with self.assertRaises(VerifyFailed): - op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_data_match_value(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) def test_find_p2pkh(self): with self.assertRaises(MissingStackItems): - op_find_p2pkh(ScriptContext(stack=[], logs=[], extras=Mock())) + op_find_p2pkh(ScriptContext(stack=[], logs=[], extras=Mock(), settings=Mock())) addr1 = '15d14K5jMqsN2uwUEFqiPG5SoD7Vr1BfnH' addr2 = '1K35zJQeYrVzQAW7X3s7vbPKmngj5JXTBc' @@ -502,64 +508,72 @@ def test_find_p2pkh(self): out_genesis = P2PKH.create_output_script(genesis_address) from hathor.transaction import Transaction, TxInput, TxOutput - spent_tx = Transaction(outputs=[TxOutput(1, b'nano_contract_code')]) - txin = TxInput(b'dont_care', 0, b'data') + spent_tx = Transaction(hash=b'some_hash', outputs=[TxOutput(1, b'nano_contract_code')]) + txin = TxInput(b'some_hash', 0, b'data') # try with just 1 output stack = [genesis_address] - tx = Transaction(outputs=[TxOutput(1, out_genesis)]) - extras = ScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx) - op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras)) + tx = Transaction(inputs=[txin], outputs=[TxOutput(1, out_genesis)]) + extras = ScriptExtras(tx=tx, spent_tx=spent_tx, input_index=0) + op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(stack.pop(), 1) # several outputs and correct output among them stack = [genesis_address] - tx = Transaction(outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(1, out_genesis), TxOutput(1, out3)]) - extras = ScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx) - op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras)) + tx = Transaction( + inputs=[txin], + outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(1, out_genesis), TxOutput(1, out3)] + ) + extras = ScriptExtras(tx=tx, spent_tx=spent_tx, input_index=0) + op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(stack.pop(), 1) # several outputs without correct amount output stack = [genesis_address] - tx = Transaction(outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(2, out_genesis), TxOutput(1, out3)]) - extras = ScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx) + tx = Transaction( + inputs=[txin], + outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(2, out_genesis), TxOutput(1, out3)] + ) + extras = ScriptExtras(tx=tx, spent_tx=spent_tx, input_index=0) with self.assertRaises(VerifyFailed): - op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras)) + op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # several outputs without correct address output stack = [genesis_address] - tx = Transaction(outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(1, out3)]) - extras = ScriptExtras(tx=tx, txin=txin, spent_tx=spent_tx) + tx = Transaction(inputs=[txin], outputs=[TxOutput(1, out1), TxOutput(1, out2), TxOutput(1, out3)]) + extras = ScriptExtras(tx=tx, spent_tx=spent_tx, input_index=0) with self.assertRaises(VerifyFailed): - op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras)) + op_find_p2pkh(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) def test_greaterthan_timestamp(self): with self.assertRaises(MissingStackItems): - op_greaterthan_timestamp(ScriptContext(stack=[], logs=[], extras=Mock())) + op_greaterthan_timestamp(ScriptContext(stack=[], logs=[], extras=Mock(), settings=Mock())) timestamp = 1234567 - from hathor.transaction import Transaction - tx = Transaction() + from hathor.transaction import Transaction, TxInput + spent_tx = Transaction(hash=b'some_hash') + tx_input = TxInput(tx_id=b'some_hash', index=0, data=b'') + tx = Transaction(inputs=[tx_input]) stack = [struct.pack('!I', timestamp)] - extras = ScriptExtras(tx=tx, txin=Mock(), spent_tx=Mock()) + extras = ScriptExtras(tx=tx, spent_tx=spent_tx, input_index=0) with self.assertRaises(TimeLocked): tx.timestamp = timestamp - 1 - op_greaterthan_timestamp(ScriptContext(stack=list(stack), logs=[], extras=extras)) + op_greaterthan_timestamp(ScriptContext(stack=list(stack), logs=[], extras=extras, settings=Mock())) with self.assertRaises(TimeLocked): tx.timestamp = timestamp - op_greaterthan_timestamp(ScriptContext(stack=list(stack), logs=[], extras=extras)) + op_greaterthan_timestamp(ScriptContext(stack=list(stack), logs=[], extras=extras, settings=Mock())) tx.timestamp = timestamp + 1 - op_greaterthan_timestamp(ScriptContext(stack=stack, logs=[], extras=extras)) + op_greaterthan_timestamp(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(len(stack), 0) def test_checkmultisig(self): with self.assertRaises(MissingStackItems): - op_checkmultisig(ScriptContext(stack=[], logs=[], extras=Mock())) + op_checkmultisig(ScriptContext(stack=[], logs=[], extras=Mock(), settings=Mock())) block = self.genesis_blocks[0] @@ -569,7 +583,7 @@ def test_checkmultisig(self): tx = Transaction(inputs=[txin], outputs=[txout]) data_to_sign = tx.get_sighash_all() - extras = ScriptExtras(tx=tx, txin=Mock(), spent_tx=Mock()) + extras = ScriptExtras(tx=tx, spent_tx=block, input_index=0) wallet = HDWallet() wallet._manually_initialize() @@ -598,92 +612,92 @@ def test_checkmultisig(self): stack = [ keys[0]['signature'], keys[2]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(1, stack.pop()) # New set of valid signatures stack = [ keys[0]['signature'], keys[1]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(1, stack.pop()) # Changing the signatures but they match stack = [ keys[1]['signature'], keys[2]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(1, stack.pop()) # Signatures are valid but in wrong order stack = [ keys[1]['signature'], keys[0]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(0, stack.pop()) # Adding wrong signature, so we get error stack = [ keys[0]['signature'], wrong_key['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(0, stack.pop()) # Adding same signature twice, so we get error stack = [ keys[0]['signature'], keys[0]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) self.assertEqual(0, stack.pop()) # Adding less signatures than required, so we get error stack = [keys[0]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3] with self.assertRaises(MissingStackItems): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # Quantity of signatures is more than it should stack = [ keys[0]['signature'], keys[1]['signature'], 3, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3 ] with self.assertRaises(MissingStackItems): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # Quantity of pubkeys is more than it should stack = [ keys[0]['signature'], keys[1]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 4 ] with self.assertRaises(InvalidStackData): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # Exception pubkey_count should be integer stack = [ keys[0]['signature'], keys[1]['signature'], 2, keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], '3' ] with self.assertRaises(InvalidStackData): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # Exception not enough pub keys stack = [keys[0]['pubkey'], keys[1]['pubkey'], 3] with self.assertRaises(MissingStackItems): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) # Exception stack empty after pubkeys stack = [keys[0]['pubkey'], keys[1]['pubkey'], keys[2]['pubkey'], 3] with self.assertRaises(MissingStackItems): - op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras)) + op_checkmultisig(ScriptContext(stack=stack, logs=[], extras=extras, settings=Mock())) def test_equal(self): elem = b'a' with self.assertRaises(MissingStackItems): - op_equal(ScriptContext(stack=[elem], logs=[], extras=Mock())) + op_equal(ScriptContext(stack=[elem], logs=[], extras=Mock(), settings=Mock())) # no exception should be raised stack = [elem, elem] - op_equal(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_equal(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), 1) stack = [elem, b'aaaa'] - op_equal(ScriptContext(stack=stack, logs=[], extras=Mock())) + op_equal(ScriptContext(stack=stack, logs=[], extras=Mock(), settings=Mock())) self.assertEqual(stack.pop(), 0) def test_integer_opcode(self): diff --git a/tests/tx/test_tx.py b/tests/tx/test_tx.py index 6c842d656..747ffa1bc 100644 --- a/tests/tx/test_tx.py +++ b/tests/tx/test_tx.py @@ -1056,7 +1056,7 @@ def test_wallet_index(self): self.assertEqual(len(self.tx_storage.indexes.addresses.get_from_address(output3_address_b58)), 1) self.assertEqual(len(self.tx_storage.indexes.addresses.get_from_address(new_address_b58)), 1) - def test_sighash_cache(self): + def test_sighash_all_cache(self): from unittest import mock address = get_address_from_public_key(self.genesis_public_key) diff --git a/tests/wallet/test_wallet_hd.py b/tests/wallet/test_wallet_hd.py index 398b3767b..8eee8f737 100644 --- a/tests/wallet/test_wallet_hd.py +++ b/tests/wallet/test_wallet_hd.py @@ -39,7 +39,7 @@ def test_transaction_and_balance(self): tx1 = self.wallet.prepare_transaction_compute_inputs(Transaction, [out], self.tx_storage) tx1.update_hash() verifier = self.manager.verification_service.verifiers.tx - verifier.verify_script(tx=tx1, input_tx=tx1.inputs[0], spent_tx=block) + verifier.verify_script(tx=tx1, spent_tx=block, input_index=0) tx1.storage = self.tx_storage tx1.get_metadata().validation = ValidationState.FULL self.wallet.on_new_tx(tx1) @@ -60,7 +60,7 @@ def test_transaction_and_balance(self): tx2.storage = self.tx_storage tx2.update_hash() tx2.storage = self.tx_storage - verifier.verify_script(tx=tx2, input_tx=tx2.inputs[0], spent_tx=tx1) + verifier.verify_script(tx=tx2, spent_tx=tx1, input_index=0) tx2.get_metadata().validation = ValidationState.FULL tx2.init_static_metadata_from_storage(self._settings, self.tx_storage) self.tx_storage.save_transaction(tx2)