Skip to content

Commit

Permalink
separate BlockCache into the (simple) production use case and the (co…
Browse files Browse the repository at this point in the history
…mplex) test use case of mocking a blockchain object
  • Loading branch information
arvidn committed Aug 19, 2024
1 parent d34aee5 commit cf9de12
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 131 deletions.
15 changes: 8 additions & 7 deletions chia/_tests/core/full_node/stores/test_full_node_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chia._tests.blockchain.blockchain_test_utils import _validate_and_add_block, _validate_and_add_block_no_error
from chia._tests.util.blockchain import create_blockchain
from chia._tests.util.blockchain_mock import BlockchainMock
from chia.consensus.blockchain import AddBlockResult, Blockchain
from chia.consensus.constants import ConsensusConstants
from chia.consensus.default_constants import DEFAULT_CONSTANTS
Expand All @@ -24,7 +25,6 @@
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.full_block import FullBlock
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.block_cache import BlockCache
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.recursive_replace import recursive_replace
Expand Down Expand Up @@ -338,7 +338,7 @@ async def test_basic_store(

assert (
store.get_finished_sub_slots(
BlockCache({}),
BlockchainMock({}),
None,
sub_slots[0].challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
)
Expand Down Expand Up @@ -379,11 +379,12 @@ async def test_basic_store(
assert slot_i is not None
assert slot_i[0] == sub_slots[i]

assert store.get_finished_sub_slots(BlockCache({}), None, sub_slots[-1].challenge_chain.get_hash()) == sub_slots
assert store.get_finished_sub_slots(BlockCache({}), None, std_hash(b"not a valid hash")) is None
assert store.get_finished_sub_slots(BlockchainMock({}), None, sub_slots[-1].challenge_chain.get_hash()) == sub_slots
assert store.get_finished_sub_slots(BlockchainMock({}), None, std_hash(b"not a valid hash")) is None

assert (
store.get_finished_sub_slots(BlockCache({}), None, sub_slots[-2].challenge_chain.get_hash()) == sub_slots[:-1]
store.get_finished_sub_slots(BlockchainMock({}), None, sub_slots[-2].challenge_chain.get_hash())
== sub_slots[:-1]
)

# Test adding genesis peak
Expand Down Expand Up @@ -736,7 +737,7 @@ async def test_basic_store(
):
sp = get_signage_point(
custom_block_tools.constants,
BlockCache({}, {}),
BlockchainMock({}, {}),
None,
uint128(0),
uint8(i),
Expand Down Expand Up @@ -771,7 +772,7 @@ async def test_basic_store(
):
sp = get_signage_point(
custom_block_tools.constants,
BlockCache({}, {}),
BlockchainMock({}, {}),
None,
uint128(slot_offset * peak.sub_slot_iters),
uint8(i),
Expand Down
123 changes: 123 additions & 0 deletions chia/_tests/util/blockchain_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

import logging
from typing import Dict, List, Optional

from chia.consensus.block_record import BlockRecord
from chia.consensus.blockchain_interface import BlockchainInterface
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.sub_epoch_summary import SubEpochSummary
from chia.types.blockchain_format.vdf import VDFInfo
from chia.types.header_block import HeaderBlock
from chia.types.weight_proof import SubEpochChallengeSegment, SubEpochSegments
from chia.util.ints import uint32


class BlockchainMock(BlockchainInterface):
def __init__(
self,
blocks: Dict[bytes32, BlockRecord],
headers: Optional[Dict[bytes32, HeaderBlock]] = None,
height_to_hash: Optional[Dict[uint32, bytes32]] = None,
sub_epoch_summaries: Optional[Dict[uint32, SubEpochSummary]] = None,
):
if sub_epoch_summaries is None:
sub_epoch_summaries = {}
if height_to_hash is None:
height_to_hash = {}
if headers is None:
headers = {}
self._block_records = blocks
self._headers = headers
self._height_to_hash = height_to_hash
self._sub_epoch_summaries = sub_epoch_summaries
self._sub_epoch_segments: Dict[bytes32, SubEpochSegments] = {}
self.log = logging.getLogger(__name__)

def get_peak(self) -> Optional[BlockRecord]:
return None

def get_peak_height(self) -> Optional[uint32]:
return None

def block_record(self, header_hash: bytes32) -> BlockRecord:
return self._block_records[header_hash]

def height_to_block_record(self, height: uint32, check_db: bool = False) -> BlockRecord:
# Precondition: height is < peak height

header_hash: Optional[bytes32] = self.height_to_hash(height)
assert header_hash is not None

return self.block_record(header_hash)

def get_ses_heights(self) -> List[uint32]:
return sorted(self._sub_epoch_summaries.keys())

def get_ses(self, height: uint32) -> SubEpochSummary:
return self._sub_epoch_summaries[height]

def height_to_hash(self, height: uint32) -> Optional[bytes32]:
assert height in self._height_to_hash
return self._height_to_hash[height]

def contains_block(self, header_hash: bytes32) -> bool:
return header_hash in self._block_records

async def contains_block_from_db(self, header_hash: bytes32) -> bool:
return header_hash in self._block_records

def contains_height(self, height: uint32) -> bool:
return height in self._height_to_hash

async def warmup(self, fork_point: uint32) -> None:
return

async def get_block_records_in_range(self, start: int, stop: int) -> Dict[bytes32, BlockRecord]:
return self._block_records

async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]:
block_records: List[BlockRecord] = []
for height in heights:
block_records.append(self.height_to_block_record(height))
return block_records

def try_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
return self._block_records.get(header_hash)

async def get_block_record_from_db(self, header_hash: bytes32) -> Optional[BlockRecord]:
return self._block_records[header_hash]

async def prev_block_hash(self, header_hashes: List[bytes32]) -> List[bytes32]:
ret = []
for h in header_hashes:
ret.append(self._block_records[h].prev_hash)
return ret

def remove_block_record(self, header_hash: bytes32) -> None:
del self._block_records[header_hash]

def add_block_record(self, block: BlockRecord) -> None:
self._block_records[block.header_hash] = block

async def get_header_blocks_in_range(
self, start: int, stop: int, tx_filter: bool = True
) -> Dict[bytes32, HeaderBlock]:
return self._headers

async def persist_sub_epoch_challenge_segments(
self, sub_epoch_summary_hash: bytes32, segments: List[SubEpochChallengeSegment]
) -> None:
self._sub_epoch_segments[sub_epoch_summary_hash] = SubEpochSegments(segments)

async def get_sub_epoch_challenge_segments(
self,
sub_epoch_summary_hash: bytes32,
) -> Optional[List[SubEpochChallengeSegment]]:
segments = self._sub_epoch_segments.get(sub_epoch_summary_hash)
if segments is None:
return None
return segments.challenge_segments

def seen_compact_proofs(self, vdf_info: VDFInfo, height: uint32) -> bool:
return False
6 changes: 3 additions & 3 deletions chia/_tests/wallet/sync/test_wallet_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from colorlog import getLogger

from chia._tests.connection_utils import disconnect_all, disconnect_all_and_reconnect
from chia._tests.util.blockchain_mock import BlockchainMock
from chia._tests.util.misc import wallet_height_at_least
from chia._tests.util.setup_nodes import OldSimulatorsAndWallets
from chia._tests.util.time_out_assert import time_out_assert, time_out_assert_not_none
Expand All @@ -39,7 +40,6 @@
from chia.types.full_block import FullBlock
from chia.types.peer_info import PeerInfo
from chia.util.batches import to_batches
from chia.util.block_cache import BlockCache
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64, uint128
from chia.wallet.nft_wallet.nft_wallet import NFTWallet
Expand Down Expand Up @@ -646,7 +646,7 @@ async def test_get_wp_fork_point(
) -> None:
blocks = default_10000_blocks
header_cache, height_to_hash, sub_blocks, summaries = await load_blocks_dont_validate(blocks, blockchain_constants)
wpf = WeightProofHandler(blockchain_constants, BlockCache(sub_blocks, header_cache, height_to_hash, summaries))
wpf = WeightProofHandler(blockchain_constants, BlockchainMock(sub_blocks, header_cache, height_to_hash, summaries))
wp1 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9_000)]].header_hash)
assert wp1 is not None
wp2 = await wpf.get_proof_of_weight(header_cache[height_to_hash[uint32(9_030)]].header_hash)
Expand Down Expand Up @@ -1410,7 +1410,7 @@ async def test_bad_peak_mismatch(
full_node_server = full_node.server
blocks = default_1000_blocks
header_cache, height_to_hash, sub_blocks, summaries = await load_blocks_dont_validate(blocks, blockchain_constants)
wpf = WeightProofHandler(blockchain_constants, BlockCache(sub_blocks, header_cache, height_to_hash, summaries))
wpf = WeightProofHandler(blockchain_constants, BlockchainMock(sub_blocks, header_cache, height_to_hash, summaries))

await wallet_server.start_client(PeerInfo(self_hostname, full_node_server.get_port()), None)

Expand Down
Loading

0 comments on commit cf9de12

Please sign in to comment.