From 9159491fa63c4f083ce16d46495b352f14483064 Mon Sep 17 00:00:00 2001 From: Gabriel Levcovitz Date: Fri, 22 Mar 2024 13:00:44 -0300 Subject: [PATCH] refactor(mypy): add stricter rules to unittest and utils --- pyproject.toml | 2 + tests/p2p/test_double_spending.py | 3 +- tests/tx/test_indexes2.py | 2 +- tests/unittest.py | 112 +++++++++++++++++++----------- tests/utils.py | 66 ++++++++++++++---- 5 files changed, 128 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 371317680..7fd634751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,8 @@ module = [ "tests.p2p.*", "tests.pubsub.*", "tests.simulation.*", + "tests.unittest", + "tests.utils", ] disallow_untyped_defs = true diff --git a/tests/p2p/test_double_spending.py b/tests/p2p/test_double_spending.py index f3f908a68..21b74d620 100644 --- a/tests/p2p/test_double_spending.py +++ b/tests/p2p/test_double_spending.py @@ -4,6 +4,7 @@ from hathor.manager import HathorManager from hathor.simulator.utils import add_new_blocks from hathor.transaction import Transaction +from hathor.util import not_none from tests import unittest from tests.utils import add_blocks_unlock_reward, add_new_tx @@ -23,7 +24,7 @@ def setUp(self) -> None: def _add_new_transactions(self, manager: HathorManager, num_txs: int) -> list[Transaction]: txs = [] for _ in range(num_txs): - address = self.get_address(0) + address = not_none(self.get_address(0)) value = self.rng.choice([5, 10, 15, 20]) tx = add_new_tx(manager, address, value) txs.append(tx) diff --git a/tests/tx/test_indexes2.py b/tests/tx/test_indexes2.py index b8df4d9eb..970903cc6 100644 --- a/tests/tx/test_indexes2.py +++ b/tests/tx/test_indexes2.py @@ -64,7 +64,7 @@ def test_timestamp_index(self): # XXX: we verified they're the same, doesn't matter which we pick: idx = idx_memory hashes = hashes_memory - self.log.debug('indexes match', idx=idx, hashes=unittest.shorten_hash(hashes)) + self.log.debug('indexes match', idx=idx, hashes=unittest.short_hashes(hashes)) if idx is None: break offset_variety.add(idx[1]) diff --git a/tests/unittest.py b/tests/unittest.py index a7468da6f..38f6c1c08 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -3,7 +3,7 @@ import shutil import tempfile import time -from typing import Iterator, Optional +from typing import Callable, Collection, Iterable, Iterator, Optional from unittest import main as ut_main from structlog import get_logger @@ -16,13 +16,17 @@ from hathor.daa import DifficultyAdjustmentAlgorithm, TestMode from hathor.event import EventManager from hathor.event.storage import EventStorage +from hathor.manager import HathorManager from hathor.p2p.peer_id import PeerId +from hathor.p2p.sync_v1.agent import NodeSyncTimestamp +from hathor.p2p.sync_v2.agent import NodeBlockSync from hathor.p2p.sync_version import SyncVersion from hathor.pubsub import PubSubManager from hathor.reactor import ReactorProtocol as Reactor, get_global_reactor from hathor.simulator.clock import MemoryReactorHeapClock -from hathor.transaction import BaseTransaction +from hathor.transaction import BaseTransaction, Block, Transaction from hathor.transaction.storage.transaction_storage import TransactionStorage +from hathor.types import VertexId from hathor.util import Random, not_none from hathor.wallet import BaseWallet, HDWallet, Wallet from tests.test_memory_reactor_clock import TestMemoryReactorClock @@ -33,9 +37,8 @@ USE_MEMORY_STORAGE = os.environ.get('HATHOR_TEST_MEMORY_STORAGE', 'false').lower() == 'true' -def shorten_hash(container): - container_type = type(container) - return container_type(h[-2:].hex() for h in container) +def short_hashes(container: Collection[bytes]) -> Iterable[str]: + return map(lambda hash_bytes: hash_bytes[-2:].hex(), container) def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]: @@ -50,7 +53,7 @@ def _load_peer_id_pool(file_path: Optional[str] = None) -> Iterator[PeerId]: yield PeerId.create_from_json(peer_id_dict) -def _get_default_peer_id_pool_filepath(): +def _get_default_peer_id_pool_filepath() -> str: this_file_path = os.path.dirname(__file__) file_name = 'peer_id_pool.json' file_path = os.path.join(this_file_path, file_name) @@ -109,8 +112,8 @@ class TestCase(unittest.TestCase): use_memory_storage: bool = USE_MEMORY_STORAGE seed_config: Optional[int] = None - def setUp(self): - self.tmpdirs = [] + def setUp(self) -> None: + self.tmpdirs: list[str] = [] self.clock = TestMemoryReactorClock() self.clock.advance(time.time()) self.log = logger.new() @@ -118,10 +121,10 @@ def setUp(self): self.seed = secrets.randbits(64) if self.seed_config is None else self.seed_config self.log.info('set seed', seed=self.seed) self.rng = Random(self.seed) - self._pending_cleanups = [] + self._pending_cleanups: list[Callable] = [] self._settings = get_global_settings() - def tearDown(self): + def tearDown(self) -> None: self.clean_tmpdirs() for fn in self._pending_cleanups: fn() @@ -144,12 +147,12 @@ def get_random_peer_id_from_pool(self, pool: Optional[list[PeerId]] = None, pool.remove(peer_id) return peer_id - def mkdtemp(self): + def mkdtemp(self) -> str: tmpdir = tempfile.mkdtemp() self.tmpdirs.append(tmpdir) return tmpdir - def _create_test_wallet(self, unlocked=False): + def _create_test_wallet(self, unlocked: bool = False) -> Wallet: """ Generate a Wallet with a number of keypairs for testing :rtype: Wallet """ @@ -169,14 +172,14 @@ def get_builder(self, network: str) -> TestBuilder: .set_network(network) return builder - def create_peer_from_builder(self, builder, start_manager=True): + def create_peer_from_builder(self, builder: Builder, start_manager: bool = True) -> HathorManager: artifacts = builder.build() manager = artifacts.manager if artifacts.rocksdb_storage: self._pending_cleanups.append(artifacts.rocksdb_storage.close) - manager.avg_time_between_blocks = 0.0001 + # manager.avg_time_between_blocks = 0.0001 # FIXME: This property is not defined. Fix this. if start_manager: manager.start() @@ -277,7 +280,7 @@ def create_peer( # type: ignore[no-untyped-def] return manager - def run_to_completion(self): + def run_to_completion(self) -> None: """ This will advance the test's clock until all calls scheduled are done. """ for call in self.clock.getDelayedCalls(): @@ -300,7 +303,11 @@ def assertIsTopological(self, tx_sequence: Iterator[BaseTransaction], message: O self.assertIn(dep, valid_deps, message) valid_deps.add(tx.hash) - def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None): + def _syncVersionFlags( + self, + enable_sync_v1: bool | None = None, + enable_sync_v2: bool | None = None + ) -> tuple[bool, bool]: """Internal: use this to check and get the flags and optionally provide override values.""" if enable_sync_v1 is None: assert hasattr(self, '_enable_sync_v1'), ('`_enable_sync_v1` has no default by design, either set one on ' @@ -313,19 +320,19 @@ def _syncVersionFlags(self, enable_sync_v1=None, enable_sync_v2=None): assert enable_sync_v1 or enable_sync_v2, 'enable at least one sync version' return enable_sync_v1, enable_sync_v2 - def assertTipsEqual(self, manager1, manager2): + def assertTipsEqual(self, manager1: HathorManager, manager2: HathorManager) -> None: _, enable_sync_v2 = self._syncVersionFlags() if enable_sync_v2: self.assertTipsEqualSyncV2(manager1, manager2) else: self.assertTipsEqualSyncV1(manager1, manager2) - def assertTipsNotEqual(self, manager1, manager2): + def assertTipsNotEqual(self, manager1: HathorManager, manager2: HathorManager) -> None: s1 = set(manager1.tx_storage.get_all_tips()) s2 = set(manager2.tx_storage.get_all_tips()) self.assertNotEqual(s1, s2) - def assertTipsEqualSyncV1(self, manager1, manager2): + def assertTipsEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None: # XXX: this is the original implementation of assertTipsEqual s1 = set(manager1.tx_storage.get_all_tips()) s2 = set(manager2.tx_storage.get_all_tips()) @@ -335,39 +342,45 @@ def assertTipsEqualSyncV1(self, manager1, manager2): s2 = set(manager2.tx_storage.get_tx_tips()) self.assertEqual(s1, s2) - def assertTipsEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True): + def assertTipsEqualSyncV2( + self, + manager1: HathorManager, + manager2: HathorManager, + *, + strict_sync_v2_indexes: bool = True + ) -> None: # tx tips if strict_sync_v2_indexes: - tips1 = manager1.tx_storage.indexes.mempool_tips.get() - tips2 = manager2.tx_storage.indexes.mempool_tips.get() + tips1 = not_none(not_none(manager1.tx_storage.indexes).mempool_tips).get() + tips2 = not_none(not_none(manager2.tx_storage.indexes).mempool_tips).get() else: tips1 = {tx.hash for tx in manager1.tx_storage.iter_mempool_tips_from_best_index()} tips2 = {tx.hash for tx in manager2.tx_storage.iter_mempool_tips_from_best_index()} - self.log.debug('tx tips1', len=len(tips1), list=shorten_hash(tips1)) - self.log.debug('tx tips2', len=len(tips2), list=shorten_hash(tips2)) + self.log.debug('tx tips1', len=len(tips1), list=short_hashes(tips1)) + self.log.debug('tx tips2', len=len(tips2), list=short_hashes(tips2)) self.assertEqual(tips1, tips2) # best block s1 = set(manager1.tx_storage.get_best_block_tips()) s2 = set(manager2.tx_storage.get_best_block_tips()) - self.log.debug('block tips1', len=len(s1), list=shorten_hash(s1)) - self.log.debug('block tips2', len=len(s2), list=shorten_hash(s2)) + self.log.debug('block tips1', len=len(s1), list=short_hashes(s1)) + self.log.debug('block tips2', len=len(s2), list=short_hashes(s2)) self.assertEqual(s1, s2) # best block (from height index) - b1 = manager1.tx_storage.indexes.height.get_tip() - b2 = manager2.tx_storage.indexes.height.get_tip() + b1 = not_none(manager1.tx_storage.indexes).height.get_tip() + b2 = not_none(manager2.tx_storage.indexes).height.get_tip() self.assertIn(b1, s2) self.assertIn(b2, s1) - def assertConsensusEqual(self, manager1, manager2): + def assertConsensusEqual(self, manager1: HathorManager, manager2: HathorManager) -> None: _, enable_sync_v2 = self._syncVersionFlags() if enable_sync_v2: self.assertConsensusEqualSyncV2(manager1, manager2) else: self.assertConsensusEqualSyncV1(manager1, manager2) - def assertConsensusEqualSyncV1(self, manager1, manager2): + def assertConsensusEqualSyncV1(self, manager1: HathorManager, manager2: HathorManager) -> None: self.assertEqual(manager1.tx_storage.get_vertices_count(), manager2.tx_storage.get_vertices_count()) for tx1 in manager1.tx_storage.get_all_transactions(): tx2 = manager2.tx_storage.get_transaction(tx1.hash) @@ -381,12 +394,20 @@ def assertConsensusEqualSyncV1(self, manager1, manager2): self.assertIsNone(tx2_meta.voided_by) else: # If tx1 is voided, then tx2 must be voided. + assert tx1_meta.voided_by is not None + assert tx2_meta.voided_by is not None self.assertGreaterEqual(len(tx1_meta.voided_by), 1) self.assertGreaterEqual(len(tx2_meta.voided_by), 1) # Hard verification # self.assertEqual(tx1_meta.voided_by, tx2_meta.voided_by) - def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_indexes=True): + def assertConsensusEqualSyncV2( + self, + manager1: HathorManager, + manager2: HathorManager, + *, + strict_sync_v2_indexes: bool = True + ) -> None: # The current sync algorithm does not propagate voided blocks/txs # so the count might be different even though the consensus is equal # One peer might have voided txs that the other does not have @@ -397,7 +418,9 @@ def assertConsensusEqualSyncV2(self, manager1, manager2, *, strict_sync_v2_index # the following is specific to sync-v2 # helper function: - def get_all_executed_or_voided(tx_storage): + def get_all_executed_or_voided( + tx_storage: TransactionStorage + ) -> tuple[set[VertexId], set[VertexId], set[VertexId]]: """Get all txs separated into three sets: executed, voided, partial""" tx_executed = set() tx_voided = set() @@ -424,14 +447,16 @@ def get_all_executed_or_voided(tx_storage): self.log.debug('node1 rest', len_voided=len(tx_voided1), len_partial=len(tx_partial1)) self.log.debug('node2 rest', len_voided=len(tx_voided2), len_partial=len(tx_partial2)) - def assertConsensusValid(self, manager): + def assertConsensusValid(self, manager: HathorManager) -> None: for tx in manager.tx_storage.get_all_transactions(): if tx.is_block: + assert isinstance(tx, Block) self.assertBlockConsensusValid(tx) else: + assert isinstance(tx, Transaction) self.assertTransactionConsensusValid(tx) - def assertBlockConsensusValid(self, block): + def assertBlockConsensusValid(self, block: Block) -> None: self.assertTrue(block.is_block) if not block.parents: # Genesis @@ -442,7 +467,8 @@ def assertBlockConsensusValid(self, block): parent_meta = parent.get_metadata() self.assertIsNone(parent_meta.voided_by) - def assertTransactionConsensusValid(self, tx): + def assertTransactionConsensusValid(self, tx: Transaction) -> None: + assert tx.storage is not None self.assertFalse(tx.is_block) meta = tx.get_metadata() if meta.voided_by and tx.hash in meta.voided_by: @@ -462,7 +488,7 @@ def assertTransactionConsensusValid(self, tx): spent_meta = spent_tx.get_metadata() if spent_meta.voided_by is not None: - self.assertIsNotNone(meta.voided_by) + assert meta.voided_by is not None self.assertTrue(spent_meta.voided_by) self.assertTrue(meta.voided_by) self.assertTrue(spent_meta.voided_by.issubset(meta.voided_by)) @@ -470,30 +496,32 @@ def assertTransactionConsensusValid(self, tx): for parent in tx.get_parents(): parent_meta = parent.get_metadata() if parent_meta.voided_by is not None: - self.assertIsNotNone(meta.voided_by) + assert meta.voided_by is not None self.assertTrue(parent_meta.voided_by) self.assertTrue(meta.voided_by) self.assertTrue(parent_meta.voided_by.issubset(meta.voided_by)) - def assertSyncedProgress(self, node_sync): + def assertSyncedProgress(self, node_sync: NodeSyncTimestamp | NodeBlockSync) -> None: """Check "synced" status of p2p-manager, uses self._enable_sync_vX to choose which check to run.""" enable_sync_v1, enable_sync_v2 = self._syncVersionFlags() if enable_sync_v2: + assert isinstance(node_sync, NodeBlockSync) self.assertV2SyncedProgress(node_sync) elif enable_sync_v1: + assert isinstance(node_sync, NodeSyncTimestamp) self.assertV1SyncedProgress(node_sync) - def assertV1SyncedProgress(self, node_sync): + def assertV1SyncedProgress(self, node_sync: NodeSyncTimestamp) -> None: self.assertEqual(node_sync.synced_timestamp, node_sync.peer_timestamp) - def assertV2SyncedProgress(self, node_sync): + def assertV2SyncedProgress(self, node_sync: NodeBlockSync) -> None: self.assertEqual(node_sync.synced_block, node_sync.peer_best_block) - def clean_tmpdirs(self): + def clean_tmpdirs(self) -> None: for tmpdir in self.tmpdirs: shutil.rmtree(tmpdir) - def clean_pending(self, required_to_quiesce=True): + def clean_pending(self, required_to_quiesce: bool = True) -> None: """ This handy method cleans all pending tasks from the reactor. diff --git a/tests/utils.py b/tests/utils.py index 4e9a212ab..bffffd613 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,9 +5,10 @@ import time import urllib.parse from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import requests +from cryptography.hazmat.primitives.asymmetric import ec from hathorlib.scripts import DataScript from twisted.internet.task import Clock @@ -19,7 +20,7 @@ from hathor.manager import HathorManager from hathor.mining.cpu_mining_service import CpuMiningService from hathor.simulator.utils import add_new_block, add_new_blocks, gen_new_double_spending, gen_new_tx -from hathor.transaction import BaseTransaction, Transaction, TxInput, TxOutput +from hathor.transaction import BaseTransaction, Block, Transaction, TxInput, TxOutput from hathor.transaction.scripts import P2PKH, HathorScript, Opcode, parse_address_script from hathor.transaction.token_creation_tx import TokenCreationTransaction from hathor.transaction.util import get_deposit_amount @@ -134,7 +135,13 @@ def add_new_double_spending(manager: HathorManager, *, use_same_parents: bool = return tx -def add_new_tx(manager, address, value, advance_clock=None, propagate=True): +def add_new_tx( + manager: HathorManager, + address: str, + value: int, + advance_clock: int | None = None, + propagate: bool = True +) -> Transaction: """ Create, resolve and propagate a new tx :param manager: Manager object to handle the creation @@ -153,11 +160,16 @@ def add_new_tx(manager, address, value, advance_clock=None, propagate=True): if propagate: manager.propagate_tx(tx, fails_silently=False) if advance_clock: - manager.reactor.advance(advance_clock) + manager.reactor.advance(advance_clock) # type: ignore[attr-defined] return tx -def add_new_transactions(manager, num_txs, advance_clock=None, propagate=True): +def add_new_transactions( + manager: HathorManager, + num_txs: int, + advance_clock: int | None = None, + propagate: bool = True +) -> list[Transaction]: """ Create, resolve and propagate some transactions :param manager: Manager object to handle the creation @@ -178,7 +190,7 @@ def add_new_transactions(manager, num_txs, advance_clock=None, propagate=True): return txs -def add_blocks_unlock_reward(manager): +def add_blocks_unlock_reward(manager: HathorManager) -> list[Block]: """This method adds new blocks to a 'burn address' to make sure the existing block rewards can be spent. It uses a 'burn address' so the manager's wallet is not impacted. @@ -186,7 +198,14 @@ def add_blocks_unlock_reward(manager): return add_new_blocks(manager, settings.REWARD_SPEND_MIN_BLOCKS, advance_clock=1, address=BURN_ADDRESS) -def run_server(hostname='localhost', listen=8005, status=8085, bootstrap=None, tries=100, alive_for_at_least_sec=3): +def run_server( + hostname: str = 'localhost', + listen: int = 8005, + status: int = 8085, + bootstrap: str | None = None, + tries: int = 100, + alive_for_at_least_sec: int = 3 +) -> subprocess.Popen[bytes]: """ Starts a full node in a subprocess running the cli command :param hostname: Hostname used to be accessed by other peers @@ -249,7 +268,14 @@ def run_server(hostname='localhost', listen=8005, status=8085, bootstrap=None, t return process -def request_server(path, method, host='http://localhost', port=8085, data=None, prefix=settings.API_VERSION_PREFIX): +def request_server( + path: str, + method: str, + host: str = 'http://localhost', + port: int = 8085, + data: dict[str, Any] | None = None, + prefix: str = settings.API_VERSION_PREFIX +) -> dict[str, Any]: """ Execute a request for status server :param path: Url path of the request @@ -283,8 +309,14 @@ def request_server(path, method, host='http://localhost', port=8085, data=None, return response.json() -def execute_mining(path='mining', *, count, host='http://localhost', port=8085, data=None, - prefix=settings.API_VERSION_PREFIX): +def execute_mining( + path: str = 'mining', + *, + count: int, + host: str = 'http://localhost', + port: int = 8085, + prefix: str = settings.API_VERSION_PREFIX +) -> None: """Execute a mining on a given server""" from hathor.cli.mining import create_parser, execute partial_url = '{}:{}/{}/'.format(host, port, prefix) @@ -294,8 +326,16 @@ def execute_mining(path='mining', *, count, host='http://localhost', port=8085, execute(args) -def execute_tx_gen(*, count, address=None, value=None, timestamp=None, host='http://localhost', port=8085, data=None, - prefix=settings.API_VERSION_PREFIX): +def execute_tx_gen( + *, + count: int, + address: str | None = None, + value: int | None = None, + timestamp: str | None = None, + host: str = 'http://localhost', + port: int = 8085, + prefix: str = settings.API_VERSION_PREFIX +) -> None: """Execute a tx generator on a given server""" from hathor.cli.tx_generator import create_parser, execute url = '{}:{}/{}/'.format(host, port, prefix) @@ -311,7 +351,7 @@ def execute_tx_gen(*, count, address=None, value=None, timestamp=None, host='htt execute(args) -def get_genesis_key(): +def get_genesis_key() -> ec.EllipticCurvePrivateKeyWithSerialization: private_key_bytes = base64.b64decode( 'MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgOCgCddzDZsfKgiMJLOt97eov9RLwHeePyBIK2WPF8MChRA' 'NCAAQ/XSOK+qniIY0F3X+lDrb55VQx5jWeBLhhzZnH6IzGVTtlAj9Ki73DVBm5+VXK400Idd6ddzS7FahBYYC7IaTl'