diff --git a/chain/src/chain.rs b/chain/src/chain.rs
index b12ec166c6..08a9f24011 100644
--- a/chain/src/chain.rs
+++ b/chain/src/chain.rs
@@ -29,7 +29,7 @@ use core::core::hash::{Hash, Hashed, ZERO_HASH};
 use core::core::merkle_proof::MerkleProof;
 use core::core::verifier_cache::VerifierCache;
 use core::core::{
-	Block, BlockHeader, BlockSums, Output, OutputIdentifier, Transaction, TxKernelEntry,
+	Block, BlockHeader, BlockSums, Output, OutputIdentifier, Transaction, TxKernel, TxKernelEntry,
 };
 use core::global;
 use core::pow;
@@ -179,7 +179,8 @@ impl Chain {
 		let store = Arc::new(chain_store);
 
 		// open the txhashset, creating a new one if necessary
-		let mut txhashset = txhashset::TxHashSet::open(db_root.clone(), store.clone(), None)?;
+		let mut txhashset =
+			txhashset::TxHashSet::open(db_root.clone(), store.clone(), None, genesis.hash())?;
 
 		setup_head(genesis.clone(), store.clone(), &mut txhashset)?;
 
@@ -345,6 +346,22 @@ impl Chain {
 		Ok(())
 	}
 
+	/// Attempt to add new kernels to the kernel mmr.
+	/// This is only ever used during sync.
+	pub fn sync_kernels(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		opts: Options,
+	) -> Result<(), Error> {
+		let mut txhashset = self.txhashset.write();
+		let batch = self.store.batch()?;
+		let mut ctx = self.new_ctx(opts, batch, &mut txhashset)?;
+
+		pipe::sync_kernels(blocks, &mut ctx)?;
+
+		Ok(())
+	}
+
 	fn new_ctx<'a>(
 		&self,
 		opts: Options,
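A minimal sketch of how a caller might feed one synced batch into this API (`block_hash` and `block_kernels` are hypothetical placeholders; `Options::SYNC` is the existing sync flag):

	// Hypothetical call site: hand one batch of (block hash, kernels) pairs
	// to the chain for verification and insertion into the kernel MMR.
	let blocks: Vec<(Hash, Vec<TxKernel>)> = vec![(block_hash, block_kernels)];
	chain.sync_kernels(&blocks, chain::Options::SYNC)?;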
@@ -789,17 +806,37 @@ impl Chain {
 		}
 
 		let header = self.get_block_header(&h)?;
-		txhashset::zip_write(self.db_root.clone(), txhashset_data, &header)?;
+		let mut kernels_already_synced = false;
+		{
+			let existing_txhashset = self.txhashset.read();
+			let num_kernels = existing_txhashset.num_kernels();
+			if num_kernels >= header.kernel_mmr_size {
+				kernels_already_synced = true;
+			}
+		}
 
-		let mut txhashset =
-			txhashset::TxHashSet::open(self.db_root.clone(), self.store.clone(), Some(&header))?;
+		txhashset::zip_write(
+			self.db_root.clone(),
+			txhashset_data,
+			&header,
+			kernels_already_synced,
+		)?;
+
+		let mut txhashset = txhashset::TxHashSet::open(
+			self.db_root.clone(),
+			self.store.clone(),
+			Some(&header),
+			header.hash(),
+		)?;
 
 		// The txhashset.zip contains the output, rangeproof and kernel MMRs.
 		// We must rebuild the header MMR ourselves based on the headers in our db.
 		self.rebuild_header_mmr(&Tip::from_header(&header), &mut txhashset)?;
 
 		// Validate the full kernel history (kernel MMR root for every block header).
-		self.validate_kernel_history(&header, &txhashset)?;
+		if !kernels_already_synced {
+			self.validate_kernel_history(&header, &txhashset)?;
+		}
 
 		// all good, prepare a new batch and update all the required records
 		debug!("txhashset_write: rewinding a 2nd time (writeable)");
@@ -973,6 +1010,24 @@ impl Chain {
 		txhashset.last_n_kernel(distance)
 	}
 
+	/// kernels by insertion index
+	pub fn get_kernels_by_insertion_index(&self, start_index: u64, max: u64) -> Vec<TxKernelEntry> {
+		let mut txhashset = self.txhashset.write();
+		txhashset.kernels_by_insertion_index(start_index, max).1
+	}
+
+	/// returns the number of leaves in the kernel mmr
+	pub fn get_num_kernels(&self) -> u64 {
+		let txhashset = self.txhashset.read();
+		txhashset.num_kernels()
+	}
+	/// returns the header of the latest block whose kernel root has been validated
+	pub fn get_kernel_root_validated_tip(&self) -> Result<BlockHeader, Error> {
+		let txhashset = self.txhashset.read();
+		let hash = txhashset.kernel_root_validated_tip();
+		self.get_block_header(&hash)
+	}
+
 	/// outputs by insertion index
 	pub fn unspent_outputs_by_insertion_index(
 		&self,
diff --git a/chain/src/pipe.rs b/chain/src/pipe.rs
index d6304661fb..0af00cabfc 100644
--- a/chain/src/pipe.rs
+++ b/chain/src/pipe.rs
@@ -27,7 +27,7 @@ use core::consensus;
 use core::core::hash::{Hash, Hashed};
 use core::core::verifier_cache::VerifierCache;
 use core::core::Committed;
-use core::core::{Block, BlockHeader, BlockSums};
+use core::core::{Block, BlockHeader, BlockSums, TxKernel};
 use core::global;
 use core::pow;
 use error::{Error, ErrorKind};
@@ -276,6 +276,68 @@ pub fn sync_block_headers(
 	}
 }
 
+/// Process the kernels.
+/// This is only ever used during sync.
+pub fn sync_kernels(
+	blocks: &Vec<(Hash, Vec<TxKernel>)>,
+	ctx: &mut BlockContext,
+) -> Result<(), Error> {
+	let first_block = match blocks.first() {
+		Some(block) => {
+			debug!(
+				"pipe: sync_kernels: {} blocks from {}",
+				blocks.len(),
+				block.0
+			);
+			block
+		}
+		_ => return Ok(()),
+	};
+
+	let header = ctx.batch.get_block_header(&first_block.0)?;
+
+	let next_kernel_index = match header.height {
+		0 => 0,
+		height => {
+			let header = match ctx.batch.get_header_by_height(height - 1) {
+				Ok(header) => header,
+				Err(e) => {
+					error!("sync_kernels: Could not find header: {:?}", e);
+					return Ok(());
+				}
+			};
+			header.kernel_mmr_size
+		}
+	};
+
+	let num_kernels = ctx.txhashset.num_kernels();
+	if num_kernels < next_kernel_index {
+		// TODO: A convenient way to store and process these later
+		// would allow nodes to request batches of kernels in parallel.
+		return Err(ErrorKind::TxHashSetErr("Previous kernels missing".to_string()).into());
+	}
+
+	let kernels_to_add: Vec<TxKernel> = blocks.iter().flat_map(|block| block.1.clone()).collect();
+	if num_kernels > (next_kernel_index + kernels_to_add.len() as u64) {
+		debug!(
+			"pipe: sync_kernels: kernels from index {} not needed.",
+			next_kernel_index,
+		);
+		return Ok(());
+	}
+
+	txhashset::extending(&mut ctx.txhashset, &mut ctx.batch, |extension| {
+		// Rewinding kernel mmr to correct kernel index. Probably unnecessary, but playing it safe.
+		extension.rewind_kernel_mmr(next_kernel_index)?;
+
+		for block in blocks {
+			extension.apply_kernels(&block.0, &block.1)?;
+		}
+		Ok(())
+	})?;
+	Ok(())
+}
+
 /// Process block header as part of "header first" block propagation.
 /// We validate the header but we do not store it or update header head based
 /// on this. We will update these once we get the block back after requesting it.
diff --git a/chain/src/txhashset/txhashset.rs b/chain/src/txhashset/txhashset.rs
index f99f63d02e..fcb560edbf 100644
--- a/chain/src/txhashset/txhashset.rs
+++ b/chain/src/txhashset/txhashset.rs
@@ -15,6 +15,7 @@
 //! Utility structs to handle the 3 MMRs (output, rangeproof,
 //! kernel) along the overall header MMR conveniently and transactionally.
 
+use std::cell::Cell;
 use std::collections::HashSet;
 use std::fs::{self, File};
 use std::path::{Path, PathBuf};
@@ -74,6 +75,7 @@ impl HashOnlyMMRHandle {
 struct PMMRHandle<T: PMMRable> {
 	backend: PMMRBackend<T>,
 	last_pos: u64,
+	root_validated_tip: Hash,
 }
 
 impl<T: PMMRable> PMMRHandle<T> {
@@ -83,12 +85,17 @@ impl<T: PMMRable> PMMRHandle<T> {
 	fn new(
 		root_dir: &str,
 		sub_dir: &str,
 		file_name: &str,
 		prunable: bool,
 		header: Option<&BlockHeader>,
+		chain_tip: Hash,
 	) -> Result<PMMRHandle<T>, Error> {
 		let path = Path::new(root_dir).join(sub_dir).join(file_name);
 		fs::create_dir_all(path.clone())?;
 		let backend = PMMRBackend::new(path.to_str().unwrap().to_string(), prunable, header)?;
 		let last_pos = backend.unpruned_size();
-		Ok(PMMRHandle { backend, last_pos })
+		Ok(PMMRHandle {
+			backend,
+			last_pos,
+			root_validated_tip: chain_tip,
+		})
 	}
 }
@@ -133,6 +140,7 @@ impl TxHashSet {
 		root_dir: String,
 		commit_index: Arc<ChainStore>,
 		header: Option<&BlockHeader>,
+		chain_tip: Hash,
 	) -> Result<TxHashSet, Error> {
 		Ok(TxHashSet {
 			header_pmmr_h: HashOnlyMMRHandle::new(
@@ -147,6 +155,7 @@ impl TxHashSet {
 				OUTPUT_SUBDIR,
 				true,
 				header,
+				chain_tip,
 			)?,
 			rproof_pmmr_h: PMMRHandle::new(
 				&root_dir,
@@ -154,6 +163,7 @@ impl TxHashSet {
 				RANGE_PROOF_SUBDIR,
 				true,
 				header,
+				chain_tip,
 			)?,
 			kernel_pmmr_h: PMMRHandle::new(
 				&root_dir,
@@ -161,6 +171,7 @@ impl TxHashSet {
 				KERNEL_SUBDIR,
 				false,
 				None,
+				chain_tip,
 			)?,
 			commit_index,
 		})
@@ -213,6 +224,28 @@ impl TxHashSet {
 		kernel_pmmr.get_last_n_insertions(distance)
 	}
 
+	/// returns kernels from the given insertion (leaf) index up to the
+	/// specified limit. Also returns the last index actually populated
+	pub fn kernels_by_insertion_index(
+		&mut self,
+		start_index: u64,
+		max_count: u64,
+	) -> (u64, Vec<TxKernelEntry>) {
+		let kernel_pmmr: PMMR<TxKernelEntry, _> =
+			PMMR::at(&mut self.kernel_pmmr_h.backend, self.kernel_pmmr_h.last_pos);
+		kernel_pmmr.elements_from_insertion_index(start_index, max_count)
+	}
+
+	/// returns the number of kernels (leaves) in the kernel_pmmr
+	pub fn num_kernels(&self) -> u64 {
+		pmmr::n_leaves(self.kernel_pmmr_h.last_pos)
+	}
+
+	/// returns the root_validated_tip of the kernel_pmmr.
+	pub fn kernel_root_validated_tip(&self) -> Hash {
+		self.kernel_pmmr_h.root_validated_tip
+	}
+
 	/// returns outputs from the given insertion (leaf) index up to the
 	/// specified limit. Also returns the last index actually populated
 	pub fn outputs_by_insertion_index(
@@ -410,6 +443,7 @@ where
 	let sizes: (u64, u64, u64, u64);
 	let res: Result<T, Error>;
 	let rollback: bool;
+	let root_validated_tips: (Hash, Hash, Hash);
 
 	// We want to use the current head of the most work chain unless
 	// we explicitly rewind the extension.
@@ -428,6 +462,7 @@ where
 
 		rollback = extension.rollback;
 		sizes = extension.sizes();
+		root_validated_tips = extension.root_validated_tips();
 	}
 
 	match res {
@@ -457,6 +492,9 @@ where
 			trees.output_pmmr_h.last_pos = sizes.1;
 			trees.rproof_pmmr_h.last_pos = sizes.2;
 			trees.kernel_pmmr_h.last_pos = sizes.3;
+			trees.output_pmmr_h.root_validated_tip = root_validated_tips.0;
+			trees.rproof_pmmr_h.root_validated_tip = root_validated_tips.1;
+			trees.kernel_pmmr_h.root_validated_tip = root_validated_tips.2;
 		}
 
 		trace!("TxHashSet extension done.");
@@ -737,6 +775,11 @@ pub struct Extension<'a> {
 	rproof_pmmr: PMMR<'a, RangeProof, PMMRBackend<RangeProof>>,
 	kernel_pmmr: PMMR<'a, TxKernelEntry, PMMRBackend<TxKernelEntry>>,
 
+	// DAVID: Make these a part of PMMR
+	output_root_validated_tip: Cell<Hash>,
+	rproof_root_validated_tip: Cell<Hash>,
+	kernel_root_validated_tip: Cell<Hash>,
+
 	/// Rollback flag.
 	rollback: bool,
 
@@ -796,6 +839,9 @@ impl<'a> Extension<'a> {
 				&mut trees.kernel_pmmr_h.backend,
 				trees.kernel_pmmr_h.last_pos,
 			),
+			output_root_validated_tip: Cell::new(trees.output_pmmr_h.root_validated_tip),
+			rproof_root_validated_tip: Cell::new(trees.rproof_pmmr_h.root_validated_tip),
+			kernel_root_validated_tip: Cell::new(trees.kernel_pmmr_h.root_validated_tip),
 			rollback: false,
 			batch,
 		}
@@ -863,8 +909,17 @@ impl<'a> Extension<'a> {
 			self.apply_input(input)?;
 		}
 
-		for kernel in b.kernels() {
-			self.apply_kernel(kernel)?;
+		// Because of kernel sync, it's possible we already have the kernels for this block.
+		let kernel_size = pmmr::n_leaves(self.kernel_pmmr.last_pos);
+		if b.header.kernel_mmr_size > kernel_size {
+			let num_kernels_behind = b.header.kernel_mmr_size - kernel_size;
+			let mut num_kernels_added = 0;
+			for kernel in b.kernels() {
+				if num_kernels_added < num_kernels_behind {
+					num_kernels_added += 1;
+					self.apply_kernel(kernel)?;
+				}
+			}
 		}
 
 		// Update the header on the extension to reflect the block we just applied.
@@ -950,8 +1005,38 @@ impl<'a> Extension<'a> {
 		Ok(output_pos)
 	}
 
+	/// Iterates through the kernels and adds them to the kernel_mmr,
+	/// verifying signatures and kernel_mmr_roots as we go.
+	pub fn apply_kernels(&mut self, hash: &Hash, kernels: &Vec<TxKernel>) -> Result<(), Error> {
+		let kernel_root_validated_tip = self.kernel_root_validated_tip.get();
+		let header = self.batch.get_block_header(hash)?;
+		if header.prev_hash != kernel_root_validated_tip {
+			return Err(ErrorKind::InvalidTxHashSet(format!(
+				"Previous hash does not match kernel tip"
+			)).into());
+		}
+
+		for kernel in kernels {
+			// Ensure kernel is self-consistent
+			kernel.verify()?;
+
+			// Apply the kernel to the kernel MMR.
+			self.apply_kernel(kernel)?;
+		}
+
+		if header.kernel_root == self.kernel_root() {
+			self.kernel_root_validated_tip.set(hash.clone());
+			Ok(())
+		} else {
+			Err(ErrorKind::InvalidTxHashSet(format!(
+				"Kernel root for block {} does not match",
+				hash
+			)).into())
+		}
+	}
+
 	/// Push kernel onto MMR (hash and data files).
-	fn apply_kernel(&mut self, kernel: &TxKernel) -> Result<(), Error> {
+	pub fn apply_kernel(&mut self, kernel: &TxKernel) -> Result<(), Error> {
 		self.kernel_pmmr
 			.push(TxKernelEntry::from(kernel.clone()))
 			.map_err(&ErrorKind::TxHashSetErr)?;
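Condensed from its use in pipe::sync_kernels above — the contract apply_kernels enforces, assuming an open `extension` and the block's `prev_header`, `header` and `kernels` bound in scope:

	// Rewind to the parent block's kernel position, then apply this block's
	// kernels; apply_kernels verifies each kernel signature and that the
	// resulting MMR root matches header.kernel_root before advancing the
	// validated tip to this block's hash.
	extension.rewind_kernel_mmr(prev_header.kernel_mmr_size)?;
	extension.apply_kernels(&header.hash(), &kernels)?;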
@@ -1020,12 +1105,50 @@ impl<'a> Extension<'a> {
 			&rewind_rm_pos,
 		)?;
 
+		let hash = header.hash();
+		self.output_root_validated_tip.set(hash);
+		self.kernel_root_validated_tip.set(hash);
+
 		// Update our header to reflect the one we rewound to.
 		self.header = header.clone();
 
 		Ok(())
 	}
 
+	/// Rewinds the kernel mmr to the provided leaf index.
+	pub fn rewind_kernel_mmr(&mut self, next_kernel_index: u64) -> Result<(), Error> {
+		debug!(
+			"Rewind kernel_pmmr to next_kernel_index {}",
+			next_kernel_index,
+		);
+
+		// Update kernel_root_validated_tip
+		if next_kernel_index == 0 {
+			self.kernel_root_validated_tip
+				.set(self.batch.get_header_by_height(0)?.hash());
+		} else if next_kernel_index > pmmr::n_leaves(self.kernel_pmmr.last_pos) {
+			return Err(ErrorKind::TxHashSetErr("Trying to rewind forward.".to_string()).into());
+		} else {
+			let current_tip = self.kernel_root_validated_tip.get();
+
+			let mut new_tip = self.batch.get_block_header(&current_tip)?;
+			while new_tip.kernel_mmr_size > next_kernel_index {
+				new_tip = self.batch.get_block_header(&new_tip.prev_hash)?;
+			}
+
+			self.kernel_root_validated_tip.set(new_tip.hash());
+		}
+
+		// Rewind kernel_pmmr
+		let kernel_pos = pmmr::insertion_to_pmmr_index(next_kernel_index);
+
+		self.kernel_pmmr
+			.rewind(kernel_pos, &Bitmap::create())
+			.map_err(&ErrorKind::TxHashSetErr)?;
+
+		Ok(())
+	}
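The rewind above relies on the usual PMMR leaf/position arithmetic from core::core::pmmr; a few concrete values (both functions are 1-indexed):

	// Leaf insertion indices 1, 2, 3, 4, 5 sit at MMR positions 1, 2, 4, 5, 8.
	assert_eq!(pmmr::insertion_to_pmmr_index(1), 1);
	assert_eq!(pmmr::insertion_to_pmmr_index(3), 4);
	assert_eq!(pmmr::insertion_to_pmmr_index(4), 5);
	// An MMR of size 7 is a perfect tree over 4 leaves (parents at 3, 6, 7).
	assert_eq!(pmmr::n_leaves(7), 4);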
+
 	/// Rewinds the MMRs to the provided positions, given the output and
 	/// kernel we want to rewind to.
 	fn rewind_to_pos(
@@ -1066,6 +1189,11 @@ impl<'a> Extension<'a> {
 		}
 	}
 
+	/// Get the root of the current kernel MMR.
+	pub fn kernel_root(&self) -> Hash {
+		self.kernel_pmmr.root()
+	}
+
 	/// Get the root of the current header MMR.
 	pub fn header_root(&self) -> Hash {
 		self.header_pmmr.root()
@@ -1096,10 +1224,24 @@ impl<'a> Extension<'a> {
 		{
 			Err(ErrorKind::InvalidRoot.into())
 		} else {
+			self.output_root_validated_tip.set(self.header.hash());
+			self.rproof_root_validated_tip.set(self.header.hash());
+			self.kernel_root_validated_tip.set(self.header.hash());
 			Ok(())
 		}
 	}
 
+	/// Gets the hashes of the last header whose output root was validated,
+	/// the last header whose rangeproof root was validated,
+	/// and the last header whose kernel root was validated.
+	pub fn root_validated_tips(&self) -> (Hash, Hash, Hash) {
+		(
+			self.output_root_validated_tip.get(),
+			self.rproof_root_validated_tip.get(),
+			self.kernel_root_validated_tip.get(),
+		)
+	}
+
 	/// Validate the provided header by comparing its prev_root to the
 	/// root of the current header MMR.
 	pub fn validate_header_root(&self, header: &BlockHeader) -> Result<(), Error> {
@@ -1107,8 +1249,8 @@ impl<'a> Extension<'a> {
 			return Ok(());
 		}
 
-		let roots = self.roots();
-		if roots.header_root != header.prev_root {
+		let header_root = self.header_pmmr.root();
+		if header_root != header.prev_root {
 			Err(ErrorKind::InvalidRoot.into())
 		} else {
 			Ok(())
@@ -1391,10 +1533,21 @@ pub fn zip_write(
 	root_dir: String,
 	txhashset_data: File,
 	header: &BlockHeader,
+	kernels_already_synced: bool,
 ) -> Result<(), Error> {
 	let txhashset_path = Path::new(&root_dir).join(TXHASHSET_SUBDIR);
 	fs::create_dir_all(txhashset_path.clone())?;
-	zip::decompress(txhashset_data, &txhashset_path)
+
+	let skip_subdirs: HashSet<String> = match kernels_already_synced {
+		true => [KERNEL_SUBDIR]
+			.iter()
+			.cloned()
+			.map(|s| String::from(s))
+			.collect(),
+		false => HashSet::new(),
+	};
+
+	zip::decompress(txhashset_data, &txhashset_path, &skip_subdirs)
 		.map_err(|ze| ErrorKind::Other(ze.to_string()))?;
 	check_and_remove_files(&txhashset_path, header)
 }
diff --git a/chain/tests/test_txhashset.rs b/chain/tests/test_txhashset.rs
index 40f872773e..318567a3a1 100644
--- a/chain/tests/test_txhashset.rs
+++ b/chain/tests/test_txhashset.rs
@@ -27,6 +27,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
 
 use chain::store::ChainStore;
 use chain::txhashset;
+use core::core::hash::ZERO_HASH;
 use core::core::BlockHeader;
 use util::file;
 
@@ -44,12 +45,14 @@ fn test_unexpected_zip() {
 	let db_env = Arc::new(store::new_env(db_root.clone()));
 	let chain_store = ChainStore::new(db_env).unwrap();
 	let store = Arc::new(chain_store);
-	txhashset::TxHashSet::open(db_root.clone(), store.clone(), None).unwrap();
+	txhashset::TxHashSet::open(db_root.clone(), store.clone(), None, ZERO_HASH).unwrap();
 	// First check if everything works out of the box
 	assert!(txhashset::zip_read(db_root.clone(), &BlockHeader::default(), Some(rand)).is_ok());
 	let zip_path = Path::new(&db_root).join(format!("txhashset_snapshot_{}.zip", rand));
 	let zip_file = File::open(&zip_path).unwrap();
-	assert!(txhashset::zip_write(db_root.clone(), zip_file, &BlockHeader::default()).is_ok());
+	assert!(
+		txhashset::zip_write(db_root.clone(), zip_file, &BlockHeader::default(), false).is_ok()
+	);
 	// Remove temp txhashset dir
 	fs::remove_dir_all(Path::new(&db_root).join(format!("txhashset_zip_{}", rand))).unwrap();
 	// Then add strange files in the original txhashset folder
@@ -64,7 +67,9 @@ fn test_unexpected_zip() {
 	fs::remove_dir_all(Path::new(&db_root).join(format!("txhashset_zip_{}", rand))).unwrap();
 
 	let zip_file = File::open(zip_path).unwrap();
-	assert!(txhashset::zip_write(db_root.clone(), zip_file, &BlockHeader::default()).is_ok());
+	assert!(
+		txhashset::zip_write(db_root.clone(), zip_file, &BlockHeader::default(), false).is_ok()
+	);
 	// Check that the txhashset dir does not contain the strange files
 	let txhashset_path = Path::new(&db_root).join("txhashset");
 	assert!(txhashset_contains_expected_files(
diff --git a/core/src/core/pmmr/pmmr.rs b/core/src/core/pmmr/pmmr.rs
index 0850bc517f..b1d64cfff1 100644
--- a/core/src/core/pmmr/pmmr.rs
+++ b/core/src/core/pmmr/pmmr.rs
@@ -616,7 +616,7 @@ pub fn bintree_rightmost(num: u64) -> u64 {
 	num - bintree_postorder_height(num)
 }
 
-/// Gets the position of the rightmost node (i.e. leaf) beneath the provided subtree root.
+/// Gets the position of the leftmost node (i.e. leaf) beneath the provided subtree root.
 pub fn bintree_leftmost(num: u64) -> u64 {
 	let height = bintree_postorder_height(num);
 	num + 2 - (2 << height)
 }
diff --git a/p2p/src/lib.rs b/p2p/src/lib.rs
index 38611a0877..27979f6575 100644
--- a/p2p/src/lib.rs
+++ b/p2p/src/lib.rs
@@ -57,5 +57,5 @@ pub use serv::{DummyAdapter, Server};
 pub use store::{PeerData, State};
 pub use types::{
 	Capabilities, ChainAdapter, Direction, Error, P2PConfig, PeerInfo, ReasonForBan, Seeding,
-	TxHashSetRead, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS,
+	TxHashSetRead, MAX_BLOCK_HEADERS, MAX_KERNEL_BLOCKS, MAX_LOCATORS, MAX_PEER_ADDRS,
 };
diff --git a/p2p/src/msg.rs b/p2p/src/msg.rs
index 3a768c77ad..46e4fb6075 100644
--- a/p2p/src/msg.rs
+++ b/p2p/src/msg.rs
@@ -22,10 +22,14 @@ use std::{thread, time};
 use core::consensus;
 use core::core::hash::Hash;
 use core::core::BlockHeader;
+use core::core::TxKernel;
 use core::pow::Difficulty;
 use core::ser::{self, Readable, Reader, Writeable, Writer};
 
-use types::{Capabilities, Error, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS};
+use types::{
+	Capabilities, Error, ReasonForBan, MAX_BLOCK_HEADERS, MAX_KERNEL_BLOCKS, MAX_LOCATORS,
+	MAX_PEER_ADDRS,
+};
 
 /// Current latest version of the protocol
 pub const PROTOCOL_VERSION: u32 = 1;
@@ -70,6 +74,8 @@ enum_from_primitive! {
 		BanReason = 18,
 		GetTransaction = 19,
 		TransactionKernel = 20,
+		GetKernels = 21,
+		Kernels = 22,
 	}
 }
 
@@ -97,6 +103,8 @@ fn max_msg_size(msg_type: Type) -> u64 {
 		Type::BanReason => 64,
 		Type::GetTransaction => 32,
 		Type::TransactionKernel => 32,
+		Type::GetKernels => 48,
+		Type::Kernels => 2 + ((34 + (114 * 20_000/*kerns per block*/)) * MAX_KERNEL_BLOCKS as u64),
 	}
 }
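Where the Kernels cap comes from, assuming the v1 kernel serialization (1-byte features, 8-byte fee, 8-byte lock_height, 33-byte excess commitment, 64-byte signature):

	// 114 bytes per kernel, 34 bytes of per-block overhead (32-byte hash plus
	// a u16 kernel count), and a leading u16 block count: roughly 146 MB.
	const KERNEL_BYTES: u64 = 1 + 8 + 8 + 33 + 64; // = 114
	const BLOCK_BYTES: u64 = 32 + 2 + KERNEL_BYTES * 20_000; // = 34 + 114 * 20_000
	const MAX_KERNELS_MSG: u64 = 2 + BLOCK_BYTES * 64; // MAX_KERNEL_BLOCKS = 64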
@@ -764,3 +772,75 @@ impl Readable for TxHashSetArchive {
 		})
 	}
 }
+
+/// Request to get the kernels for each block starting with the block at the specified height.
+pub struct GetKernels {
+	/// The height of the first block kernels are being requested for.
+	pub first_block_height: u64,
+}
+
+impl Writeable for GetKernels {
+	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
+		writer.write_u64(self.first_block_height)?;
+		Ok(())
+	}
+}
+
+impl Readable for GetKernels {
+	fn read(reader: &mut Reader) -> Result<GetKernels, ser::Error> {
+		Ok(GetKernels {
+			first_block_height: reader.read_u64()?,
+		})
+	}
+}
+
+/// A block hash and the kernels belonging to that block.
+pub struct BlockKernels {
+	/// Hash of the block the kernels belong to.
+	pub hash: Hash,
+	/// The block's kernels in the order they appear in the Kernel MMR leafset.
+	pub kernels: Vec<TxKernel>,
+}
+
+/// Response to a GetKernels request containing the requested kernels.
+pub struct Kernels {
+	/// The blocks, in order, and their corresponding kernels.
+	pub blocks: Vec<BlockKernels>,
+}
+
+impl Writeable for Kernels {
+	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
+		writer.write_u16(self.blocks.len() as u16)?;
+
+		for block in &self.blocks {
+			block.hash.write(writer)?;
+
+			writer.write_u16(block.kernels.len() as u16)?;
+			for kernel in &block.kernels {
+				kernel.write(writer)?;
+			}
+		}
+		Ok(())
+	}
+}
+
+impl Readable for Kernels {
+	fn read(reader: &mut Reader) -> Result<Kernels, ser::Error> {
+		let num_blocks = reader.read_u16()?;
+		let mut blocks = Vec::with_capacity(num_blocks as usize);
+
+		for _ in 0..num_blocks {
+			let hash = Hash::read(reader)?;
+
+			let num_kernels = reader.read_u16()?;
+			let mut kernels = Vec::with_capacity(num_kernels as usize);
+			for _ in 0..num_kernels {
+				kernels.push(TxKernel::read(reader)?);
+			}
+
+			blocks.push(BlockKernels { hash, kernels });
+		}
+
+		Ok(Kernels { blocks })
+	}
+}
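A test-style round trip for the new messages through the existing ser helpers (`hash` and `kernels` assumed bound; `ser_vec` and `deserialize` are the stock core::ser functions):

	// Serialize a Kernels message and read it back.
	let msg = Kernels {
		blocks: vec![BlockKernels { hash, kernels }],
	};
	let buf = ser::ser_vec(&msg).unwrap();
	let round_trip: Kernels = ser::deserialize(&mut &buf[..]).unwrap();
	assert_eq!(round_trip.blocks.len(), 1);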
diff --git a/p2p/src/peer.rs b/p2p/src/peer.rs
index b833602d04..f1ea8a29a8 100644
--- a/p2p/src/peer.rs
+++ b/p2p/src/peer.rs
@@ -23,7 +23,7 @@ use core::core::hash::{Hash, Hashed};
 use core::pow::Difficulty;
 use core::{core, global};
 use handshake::Handshake;
-use msg::{self, BanReason, GetPeerAddrs, Locator, Ping, TxHashSetRequest};
+use msg::{self, BanReason, GetKernels, GetPeerAddrs, Locator, Ping, TxHashSetRequest};
 use protocol::Protocol;
 use types::{
 	Capabilities, ChainAdapter, Error, NetAdapter, P2PConfig, PeerInfo, ReasonForBan, TxHashSetRead,
@@ -415,6 +415,34 @@ impl Peer {
 		)
 	}
 
+	/// Requests the kernels for each block, starting with the block at the specified height.
+	/// NOTE: Only sends the request if the remote peer has the ENHANCED_TXHASHSET_HIST capability.
+	pub fn send_kernel_request(&self, first_block_height: u64) -> Result<bool, Error> {
+		if self
+			.info
+			.capabilities
+			.contains(Capabilities::ENHANCED_TXHASHSET_HIST)
+		{
+			trace!(
+				"Asking {} for kernels starting with block at {}.",
+				self.info.addr,
+				first_block_height,
+			);
+			self.connection
+				.as_ref()
+				.unwrap()
+				.lock()
+				.send(&GetKernels { first_block_height }, msg::Type::GetKernels)?;
+			Ok(true)
+		} else {
+			trace!(
+				"Not requesting kernels from {} (peer not capable)",
+				self.info.addr
+			);
+			Ok(false)
+		}
+	}
+
 	/// Stops the peer, closing its connection
 	pub fn stop(&self) {
 		stop_with_connection(&self.connection.as_ref().unwrap().lock());
@@ -579,6 +607,18 @@ impl ChainAdapter for TrackingAdapter {
 		self.adapter
 			.txhashset_download_update(start_time, downloaded_size, total_size)
 	}
+
+	fn read_kernels(&self, first_block_height: u64) -> Vec<(Hash, Vec<TxKernel>)> {
+		self.adapter.read_kernels(first_block_height)
+	}
+
+	fn kernels_received(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		peer_addr: SocketAddr,
+	) -> bool {
+		self.adapter.kernels_received(blocks, peer_addr)
+	}
 }
 
 impl NetAdapter for TrackingAdapter {
diff --git a/p2p/src/peers.rs b/p2p/src/peers.rs
index 446f0bbc82..a67d3dc6a5 100644
--- a/p2p/src/peers.rs
+++ b/p2p/src/peers.rs
@@ -581,6 +581,18 @@ impl ChainAdapter for Peers {
 		self.adapter
 			.txhashset_download_update(start_time, downloaded_size, total_size)
 	}
+
+	fn read_kernels(&self, first_block_height: u64) -> Vec<(Hash, Vec<TxKernel>)> {
+		self.adapter.read_kernels(first_block_height)
+	}
+
+	fn kernels_received(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		peer_addr: SocketAddr,
+	) -> bool {
+		self.adapter.kernels_received(blocks, peer_addr)
+	}
 }
 
 impl NetAdapter for Peers {
diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs
index b94e1f3f07..ede1860748 100644
--- a/p2p/src/protocol.rs
+++ b/p2p/src/protocol.rs
@@ -27,8 +27,8 @@ use core::{global, ser};
 use util::{RateCounter, RwLock};
 
 use msg::{
-	read_exact, BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr,
-	TxHashSetArchive, TxHashSetRequest, Type,
+	read_exact, BanReason, BlockKernels, GetKernels, GetPeerAddrs, Headers, Kernels, Locator,
+	PeerAddrs, Ping, Pong, SockAddr, TxHashSetArchive, TxHashSetRequest, Type,
 };
 use types::{Error, NetAdapter};
 
@@ -336,6 +336,36 @@ impl MessageHandler for Protocol {
 				Ok(None)
 			}
 
+			Type::GetKernels => {
+				// DAVID: Check Capabilities first?
+				// Retrieve kernels from the kernel MMR
+				let request: GetKernels = msg.body()?;
+				let kernels_by_block = adapter.read_kernels(request.first_block_height);
+
+				let blocks: Vec<BlockKernels> = kernels_by_block
+					.iter()
+					.map(|block| BlockKernels {
+						hash: block.0,
+						kernels: block.1.clone(),
+					}).collect();
+
+				// serialize and send all the kernels over
+				Ok(Some(msg.respond(Type::Kernels, Kernels { blocks })))
+			}
+
+			Type::Kernels => {
+				let kernels: Kernels = msg.body()?;
+				let blocks: Vec<(Hash, Vec<TxKernel>)> = kernels
+					.blocks
+					.iter()
+					.map(|block| (block.hash, block.kernels.clone()))
+					.collect();
+
+				adapter.kernels_received(&blocks, self.addr);
+
+				Ok(None)
+			}
+
 			_ => {
 				debug!("unknown message type {:?}", msg.header.msg_type);
 				Ok(None)
diff --git a/p2p/src/serv.rs b/p2p/src/serv.rs
index 3d6ebf9abf..e5a6a0dd54 100644
--- a/p2p/src/serv.rs
+++ b/p2p/src/serv.rs
@@ -260,6 +260,18 @@ impl ChainAdapter for DummyAdapter {
 	) -> bool {
 		false
 	}
+
+	fn read_kernels(&self, first_block_height: u64) -> Vec<(Hash, Vec<TxKernel>)> {
+		vec![]
+	}
+
+	fn kernels_received(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		peer_addr: SocketAddr,
+	) -> bool {
+		true
+	}
 }
 
 impl NetAdapter for DummyAdapter {
diff --git a/p2p/src/types.rs b/p2p/src/types.rs
index 5f6b87078a..5909b68fa8 100644
--- a/p2p/src/types.rs
+++ b/p2p/src/types.rs
@@ -40,6 +40,9 @@ pub const MAX_PEER_ADDRS: u32 = 256;
 /// Maximum number of block header hashes to send as part of a locator
 pub const MAX_LOCATORS: u32 = 20;
 
+/// Maximum number of blocks a peer should ever send kernels for
+pub const MAX_KERNEL_BLOCKS: u32 = 64;
+
 /// How long a banned peer should be banned for
 const BAN_WINDOW: i64 = 10800;
 
@@ -210,6 +213,8 @@ bitflags! {
 		const PEER_LIST = 0b00000100;
 		/// Can broadcast and request txs by kernel hash.
 		const TX_KERNEL_HASH = 0b00001000;
+		/// Provides ability to request kernels separately from the rest of the TxHashSet.
+		const ENHANCED_TXHASHSET_HIST = 0b00010000;
 
 		/// All nodes right now are "full nodes".
 		/// Some nodes internally may maintain longer block histories (archival_mode)
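The new bit composes with the existing flags in the usual bitflags way:

	// A peer advertising bit 4 can serve kernel batches during sync.
	let caps = Capabilities::from_bits_truncate(0b00010000);
	assert!(caps.contains(Capabilities::ENHANCED_TXHASHSET_HIST));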
@@ -373,7 +378,7 @@ pub trait ChainAdapter: Sync + Send {
 	fn get_block(&self, h: Hash) -> Option<Block>;
 
 	/// Provides a reading view into the current txhashset state as well as
-	/// the required indexes for a consumer to rewind to a consistant state
+	/// the required indexes for a consumer to rewind to a consistent state
 	/// at the provided block hash.
 	fn txhashset_read(&self, h: Hash) -> Option<TxHashSetRead>;
 
@@ -396,6 +401,17 @@ pub trait ChainAdapter: Sync + Send {
 	/// read as a zip file, unzipped and the resulting state files should be
 	/// rewound to the provided indexes.
 	fn txhashset_write(&self, h: Hash, txhashset_data: File, peer_addr: SocketAddr) -> bool;
+
+	/// Finds the kernels for each block starting at the given block height,
+	/// returning them grouped by block hash.
+	fn read_kernels(&self, first_block_height: u64) -> Vec<(Hash, Vec<TxKernel>)>;
+
+	/// A set of kernels has been received.
+	fn kernels_received(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		peer_addr: SocketAddr,
+	) -> bool;
 }
 
 /// Additional methods required by the protocol that don't need to be
diff --git a/p2p/tests/ser_deser.rs b/p2p/tests/ser_deser.rs
index f7710a982e..eca41dbd64 100644
--- a/p2p/tests/ser_deser.rs
+++ b/p2p/tests/ser_deser.rs
@@ -64,7 +64,7 @@ fn test_capabilities() {
 	);
 	assert_eq!(
 		p2p::types::Capabilities::from_bits_truncate(0b11110111 as u32),
-		p2p::types::Capabilities::FULL_NODE
+		p2p::types::Capabilities::ENHANCED_TXHASHSET_HIST | p2p::types::Capabilities::FULL_NODE
 	);
 	assert_eq!(
 		p2p::types::Capabilities::from_bits_truncate(0b00100111 as u32),
diff --git a/servers/src/common/adapters.rs b/servers/src/common/adapters.rs
index bbff07deb5..15c2d684bc 100644
--- a/servers/src/common/adapters.rs
+++ b/servers/src/common/adapters.rs
@@ -15,6 +15,7 @@
 //! Adapters connecting new block, new transaction, and accepted transaction
 //! events to consumers of those events.
 
+use std::cmp::min;
 use std::fs::File;
 use std::net::SocketAddr;
 use std::sync::{Arc, Weak};
@@ -369,6 +370,91 @@ impl p2p::ChainAdapter for NetToChainAdapter {
 			true
 		}
 	}
+
+	fn read_kernels(&self, first_block_height: u64) -> Vec<(Hash, Vec<TxKernel>)> {
+		let head = match self.chain().head() {
+			Ok(head) => head,
+			Err(e) => {
+				error!("read_kernels: Could not find head: {:?}", e);
+				return vec![];
+			}
+		};
+
+		if first_block_height > head.height {
+			debug!("read_kernels: Requesting beyond chain height.");
+			return vec![];
+		}
+
+		let mut next_kernel_index = match first_block_height {
+			0 => 0,
+			height => {
+				let header = match self.chain().get_header_by_height(height - 1) {
+					Ok(header) => header,
+					Err(e) => {
+						error!("read_kernels: Could not find header: {:?}", e);
+						return vec![];
+					}
+				};
+				header.kernel_mmr_size
+			}
+		};
+
+		let mut blocks = vec![];
+		let mut index = first_block_height;
+		while blocks.len() < p2p::MAX_KERNEL_BLOCKS as usize && index <= head.height {
+			let header = match self.chain().get_header_by_height(index) {
+				Ok(header) => header,
+				Err(e) => {
+					error!("read_kernels: Could not find header: {:?}", e);
+					return vec![];
+				}
+			};
+
+			let kernels_to_read = header.kernel_mmr_size - next_kernel_index + 1;
+			let kernels: Vec<TxKernel> = self
+				.chain()
+				.get_kernels_by_insertion_index(next_kernel_index, kernels_to_read)
+				.iter()
+				.map(|entry| entry.kernel.clone())
+				.collect();
+
+			blocks.push((header.hash(), kernels));
+			next_kernel_index = header.kernel_mmr_size + 1;
+			index += 1;
+		}
+
+		debug!("read_kernels: returning kernel blocks: {}", blocks.len());
+
+		blocks
+	}
+
+	fn kernels_received(
+		&self,
+		blocks: &Vec<(Hash, Vec<TxKernel>)>,
+		peer_addr: SocketAddr,
+	) -> bool {
+		info!(
+			"Received kernels in blocks {:?} from {}",
+			blocks.iter().map(|x| x.0).collect::<Vec<_>>(),
+			peer_addr,
+		);
+
+		if blocks.len() == 0 {
+			return false;
+		}
+
+		// try to add kernels to our kernel MMR
+		let res = self.chain().sync_kernels(blocks, self.chain_opts());
+		if let &Err(ref e) = &res {
+			debug!("Kernels refused by chain: {:?}", e);
+
+			if e.is_bad_data() {
+				return false;
+			}
+		}
+
+		true
+	}
 }
 
 impl NetToChainAdapter {
diff --git a/servers/src/common/types.rs b/servers/src/common/types.rs
index e872498694..b3f6cffd73 100644
--- a/servers/src/common/types.rs
+++ b/servers/src/common/types.rs
@@ -244,6 +244,11 @@ pub enum SyncStatus {
 		current_height: u64,
 		highest_height: u64,
 	},
+	/// Downloading kernels
+	KernelSync {
+		current_height: u64,
+		highest_height: u64,
+	},
 	/// Downloading the various txhashsets
 	TxHashsetDownload {
		start_time: DateTime<Utc>,
diff --git a/servers/src/grin/server.rs b/servers/src/grin/server.rs
index 255ce1c412..26f3626a0a 100644
--- a/servers/src/grin/server.rs
+++ b/servers/src/grin/server.rs
@@ -221,6 +221,7 @@ impl Server {
 			p2p_server.peers.clone(),
 			shared_chain.clone(),
 			stop.clone(),
+			config.p2p_config.capabilities,
 		);
 
 		let p2p_inner = p2p_server.clone();
diff --git a/servers/src/grin/sync/kernel_sync.rs b/servers/src/grin/sync/kernel_sync.rs
new file mode 100644
index 0000000000..73da738561
--- /dev/null
+++ b/servers/src/grin/sync/kernel_sync.rs
@@ -0,0 +1,156 @@
+// Copyright 2018 The Grin Developers
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use chrono::prelude::{DateTime, Utc};
+use chrono::Duration;
+use std::sync::Arc;
+
+use chain;
+use chain::Tip;
+use common::types::{SyncState, SyncStatus};
+use core::core::hash::Hashed;
+use core::core::BlockHeader;
+use p2p;
+use p2p::types::Capabilities;
+
+/// Fast sync has 4 "states":
+/// * syncing headers
+/// * once all headers are sync'd, sync kernels
+/// * once kernels are sync'd, requesting the txhashset state
+/// * once we have the state, get blocks after that
+///
+/// The KernelSync struct implements and monitors the second step.
+pub struct KernelSync {
+	sync_state: Arc<SyncState>,
+	peers: Arc<p2p::Peers>,
+	chain: Arc<chain::Chain>,
+	capabilities: p2p::Capabilities,
+
+	/// Holds the timeout, num kernels received, and previous num kernels received
+	/// at the time of the previous kernel sync.
+	prev_kernel_sync: (DateTime<Utc>, u64, u64),
+}
+
+impl KernelSync {
+	pub fn new(
+		sync_state: Arc<SyncState>,
+		peers: Arc<p2p::Peers>,
+		chain: Arc<chain::Chain>,
+		capabilities: p2p::Capabilities,
+	) -> KernelSync {
+		KernelSync {
+			sync_state,
+			peers,
+			chain,
+			capabilities,
+			prev_kernel_sync: (Utc::now(), 0, 0),
+		}
+	}
+
+	/// Check whether kernel sync should run and requests kernels from capable peers.
+	pub fn check_run(&mut self) -> bool {
+		let enable_kernel_sync = self
+			.capabilities
+			.contains(Capabilities::ENHANCED_TXHASHSET_HIST);
+
+		if enable_kernel_sync {
+			let header_head = match self.chain.header_head() {
+				Ok(header_head) => header_head,
+				Err(e) => {
+					error!("kernel_sync: check_run err! {:?}", e);
+					return false;
+				}
+			};
+
+			let kernel_tip = match self.chain.get_kernel_root_validated_tip() {
+				Ok(kernel_tip) => kernel_tip,
+				Err(e) => {
+					error!("kernel_sync: check_run err! {:?}", e);
+					return false;
+				}
+			};
+
+			if !self.kernel_sync_due(&header_head, &kernel_tip) {
+				return false;
+			}
+
+			self.sync_state.update(SyncStatus::KernelSync {
+				current_height: kernel_tip.height,
+				highest_height: header_head.height,
+			});
+
+			// DAVID: If no capable peer exists, fall back to full txhashset download
+			self.kernel_sync(kernel_tip.height + 1);
+
+			return true;
+		}
+		false
+	}
+
+	fn kernel_sync_due(&mut self, header_head: &Tip, kernel_tip: &BlockHeader) -> bool {
+		// Kernels are up to date on the current fork.
+		if kernel_tip.height + 5 > header_head.height {
+			return false;
+		}
+
+		let now = Utc::now();
+		let (timeout, last_kernel_blocks_received, prev_kernel_blocks_received) =
+			self.prev_kernel_sync;
+
+		// received all necessary kernels, can ask for more
+		let can_request_more =
+			kernel_tip.height >= prev_kernel_blocks_received + (p2p::MAX_KERNEL_BLOCKS as u64);
+
+		// no kernels processed and we're past timeout, need to ask for more
+		let stalling = kernel_tip.height <= last_kernel_blocks_received && now > timeout;
+
+		if can_request_more || stalling {
+			self.prev_kernel_sync = (
+				now + Duration::seconds(10),
+				kernel_tip.height,
+				kernel_tip.height,
+			);
+			true
+		} else {
+			// resetting the timeout as long as we progress
+			if kernel_tip.height > last_kernel_blocks_received {
+				self.prev_kernel_sync = (
+					now + Duration::seconds(2),
+					kernel_tip.height,
+					prev_kernel_blocks_received,
+				);
+			}
+			false
+		}
+	}
+
+	fn kernel_sync(&mut self, first_block_height: u64) -> Result<(), p2p::Error> {
+		let opt_peer = self.peers.most_work_peers().into_iter().find(|peer| {
+			peer.info
+				.capabilities
+				.contains(Capabilities::ENHANCED_TXHASHSET_HIST)
+		});
+
+		if let Some(peer) = opt_peer {
+			debug!(
+				"kernel_sync: asking {} for kernels starting at block {:?}",
+				peer.info.addr, first_block_height
+			);
+
+			let _ = peer.send_kernel_request(first_block_height);
+			return Ok(());
+		}
+		Err(p2p::Error::PeerException)
+	}
+}
diff --git a/servers/src/grin/sync/mod.rs b/servers/src/grin/sync/mod.rs
index 56c30850f8..9712905acf 100644
--- a/servers/src/grin/sync/mod.rs
+++ b/servers/src/grin/sync/mod.rs
@@ -16,6 +16,7 @@
 
 mod body_sync;
 mod header_sync;
+mod kernel_sync;
 mod state_sync;
 mod syncer;
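Illustrative numbers for the two pacing rules in kernel_sync_due (values assumed):

	// After a request made at validated tip 20, we either wait until the tip
	// clears 20 + MAX_KERNEL_BLOCKS = 84 (can_request_more), or re-request
	// once the 10s timeout passes with no progress at all (stalling).
	let (tip, last, prev, timed_out) = (100u64, 100u64, 20u64, false);
	let can_request_more = tip >= prev + 64; // 64 = p2p::MAX_KERNEL_BLOCKS
	let stalling = tip <= last && timed_out;
	assert!(can_request_more && !stalling);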
diff --git a/servers/src/grin/sync/state_sync.rs b/servers/src/grin/sync/state_sync.rs
index e1c3fd022e..f14fe3493c 100644
--- a/servers/src/grin/sync/state_sync.rs
+++ b/servers/src/grin/sync/state_sync.rs
@@ -20,18 +20,20 @@ use chain;
 use common::types::{Error, SyncState, SyncStatus};
 use core::core::hash::Hashed;
 use core::global;
-use p2p::{self, Peer};
+use p2p::{self, Capabilities, Peer};
 
-/// Fast sync has 3 "states":
+/// Fast sync has 4 "states":
 /// * syncing headers
-/// * once all headers are sync'd, requesting the txhashset state
+/// * once all headers are sync'd, sync kernels
+/// * once kernels are sync'd, requesting the txhashset state
 /// * once we have the state, get blocks after that
 ///
-/// The StateSync struct implements and monitors the middle step.
+/// The StateSync struct implements and monitors the third step.
 pub struct StateSync {
 	sync_state: Arc<SyncState>,
 	peers: Arc<p2p::Peers>,
 	chain: Arc<chain::Chain>,
+	capabilities: p2p::Capabilities,
 
 	prev_state_sync: Option<DateTime<Utc>>,
 	state_sync_peer: Option<Arc<Peer>>,
@@ -42,11 +44,13 @@ impl StateSync {
 		sync_state: Arc<SyncState>,
 		peers: Arc<p2p::Peers>,
 		chain: Arc<chain::Chain>,
+		capabilities: p2p::Capabilities,
 	) -> StateSync {
 		StateSync {
 			sync_state,
 			peers,
 			chain,
+			capabilities,
 			prev_state_sync: None,
 			state_sync_peer: None,
 		}
@@ -63,7 +67,7 @@ impl StateSync {
 		highest_height: u64,
 	) -> bool {
 		trace!("state_sync: head.height: {}, tail.height: {}. header_head.height: {}, highest_height: {}",
-			head.height, tail.height, header_head.height, highest_height,
+			head.height, tail.height, header_head.height, highest_height
 		);
 
 		let mut sync_need_restart = false;
@@ -158,7 +162,7 @@ impl StateSync {
 	fn request_state(&self, header_head: &chain::Tip) -> Result<Arc<Peer>, p2p::Error> {
 		let threshold = global::state_sync_threshold() as u64;
 
-		if let Some(peer) = self.peers.most_work_peer() {
+		if let Some(peer) = self.find_peer() {
 			// ask for txhashset at state_sync_threshold
 			let mut txhashset_head = self
 				.chain
@@ -184,6 +188,24 @@ impl StateSync {
 		Err(p2p::Error::PeerException)
 	}
 
+	fn find_peer(&self) -> Option<Arc<Peer>> {
+		if self
+			.capabilities
+			.contains(Capabilities::ENHANCED_TXHASHSET_HIST)
+		{
+			let opt_enhanced_peer = self.peers.most_work_peers().into_iter().find(|peer| {
+				peer.info
+					.capabilities
+					.contains(Capabilities::ENHANCED_TXHASHSET_HIST)
+			});
+
+			if let Some(enhanced_peer) = opt_enhanced_peer {
+				return Some(enhanced_peer);
+			}
+		}
+		self.peers.most_work_peer()
+	}
+
 	// For now this is a one-time thing (it can be slow) at initial startup.
 	fn state_sync_due(&mut self) -> (bool, bool) {
 		let now = Utc::now();
diff --git a/servers/src/grin/sync/syncer.rs b/servers/src/grin/sync/syncer.rs
index 3f6496c8ff..e418cc4cbf 100644
--- a/servers/src/grin/sync/syncer.rs
+++ b/servers/src/grin/sync/syncer.rs
@@ -22,6 +22,7 @@ use common::types::{SyncState, SyncStatus};
 use core::pow::Difficulty;
 use grin::sync::body_sync::BodySync;
 use grin::sync::header_sync::HeaderSync;
+use grin::sync::kernel_sync::KernelSync;
 use grin::sync::state_sync::StateSync;
 use p2p;
 
@@ -30,11 +31,12 @@ pub fn run_sync(
 	peers: Arc<p2p::Peers>,
 	chain: Arc<chain::Chain>,
 	stop: Arc<AtomicBool>,
+	capabilities: p2p::Capabilities,
 ) {
 	let _ = thread::Builder::new()
 		.name("sync".to_string())
 		.spawn(move || {
-			let runner = SyncRunner::new(sync_state, peers, chain, stop);
+			let runner = SyncRunner::new(sync_state, peers, chain, stop, capabilities);
 			runner.sync_loop();
 		});
 }
@@ -44,6 +46,7 @@ pub struct SyncRunner {
 	peers: Arc<p2p::Peers>,
 	chain: Arc<chain::Chain>,
 	stop: Arc<AtomicBool>,
+	capabilities: p2p::Capabilities,
 }
 
 impl SyncRunner {
@@ -52,12 +55,14 @@ impl SyncRunner {
 		peers: Arc<p2p::Peers>,
 		chain: Arc<chain::Chain>,
 		stop: Arc<AtomicBool>,
+		capabilities: p2p::Capabilities,
 	) -> SyncRunner {
 		SyncRunner {
 			sync_state,
 			peers,
 			chain,
 			stop,
+			capabilities,
 		}
 	}
 
@@ -99,12 +104,18 @@ impl SyncRunner {
 		// Wait for connections reach at least MIN_PEERS
 		self.wait_for_min_peers();
 
-		// Our 3 main sync stages
+		// Our 4 main sync stages
 		let mut header_sync = HeaderSync::new(
 			self.sync_state.clone(),
 			self.peers.clone(),
 			self.chain.clone(),
 		);
+		let mut kernel_sync = KernelSync::new(
+			self.sync_state.clone(),
+			self.peers.clone(),
+			self.chain.clone(),
+			self.capabilities,
+		);
 		let mut body_sync = BodySync::new(
 			self.sync_state.clone(),
 			self.peers.clone(),
@@ -114,6 +125,7 @@ impl SyncRunner {
 			self.sync_state.clone(),
 			self.peers.clone(),
 			self.chain.clone(),
+			self.capabilities,
 		);
 
 		// Highest height seen on the network, generally useful for a fast test on
@@ -168,7 +180,9 @@ impl SyncRunner {
 		}
 
 		if check_state_sync {
-			state_sync.check_run(&header_head, &head, &tail, highest_height);
+			if !kernel_sync.check_run() {
+				state_sync.check_run(&header_head, &head, &tail, highest_height);
+			}
 		}
 	}
 }
diff --git a/src/bin/tui/status.rs b/src/bin/tui/status.rs
index d123b95209..1e0c0dccf9 100644
--- a/src/bin/tui/status.rs
+++ b/src/bin/tui/status.rs
@@ -108,7 +108,18 @@ impl TUIStatusListener for TUIStatusView {
 				} else {
 					current_height * 100 / highest_height
 				};
-					format!("Downloading headers: {}%, step 1/4", percent)
+					format!("Downloading headers: {}%, step 1/5", percent)
+				}
+				SyncStatus::KernelSync {
+					current_height,
+					highest_height,
+				} => {
+					let percent = if highest_height == 0 {
+						0
+					} else {
+						current_height * 100 / highest_height
+					};
+					format!("Downloading kernels: {}%, step 2/5", percent)
 				}
 				SyncStatus::TxHashsetDownload {
 					start_time,
@@ -125,7 +136,7 @@ impl TUIStatusListener for TUIStatusView {
 						let fin = Utc::now().timestamp_nanos();
 						let dur_ms = (fin - start) as f64 * NANO_TO_MILLIS;
 
-						format!("Downloading {}(MB) chain state for state sync: {}% at {:.1?}(kB/s), step 2/4",
+						format!("Downloading {}(MB) chain state for state sync: {}% at {:.1?}(kB/s), step 3/5",
 							total_size / 1_000_000,
 							percent,
 							if dur_ms > 1.0f64 { downloaded_size as f64 / dur_ms as f64 } else { 0f64 },
@@ -135,13 +146,13 @@ impl TUIStatusListener for TUIStatusView {
 						let fin = Utc::now().timestamp_millis();
 						let dur_secs = (fin - start) / 1000;
 
-						format!("Downloading chain state for state sync. Waiting remote peer to start: {}s, step 2/4",
+						format!("Downloading chain state for state sync. Waiting remote peer to start: {}s, step 3/5",
 							dur_secs,
 						)
 					}
 				}
 				SyncStatus::TxHashsetSetup => {
-					"Preparing chain state for validation, step 3/4".to_string()
+					"Preparing chain state for validation, step 4/5".to_string()
 				}
 				SyncStatus::TxHashsetValidation {
 					kernels,
@@ -161,13 +172,13 @@ impl TUIStatusListener for TUIStatusView {
 					} else {
 						0
 					};
-					format!("Validating chain state: {}%, step 3/4", percent)
+					format!("Validating chain state: {}%, step 4/5", percent)
 				}
 				SyncStatus::TxHashsetSave => {
-					"Finalizing chain state for state sync, step 3/4".to_string()
+					"Finalizing chain state for state sync, step 4/5".to_string()
 				}
 				SyncStatus::TxHashsetDone => {
-					"Finalized chain state for state sync, step 3/4".to_string()
+					"Finalized chain state for state sync, step 4/5".to_string()
 				}
 				SyncStatus::BodySync {
 					current_height,
@@ -178,7 +189,7 @@ impl TUIStatusListener for TUIStatusView {
 					} else {
 						current_height * 100 / highest_height
 					};
-					format!("Downloading blocks: {}%, step 4/4", percent)
+					format!("Downloading blocks: {}%, step 5/5", percent)
 				}
 			}
 		};
diff --git a/util/src/zip.rs b/util/src/zip.rs
index b453e91c79..36cf0b31b4 100644
--- a/util/src/zip.rs
+++ b/util/src/zip.rs
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashSet;
 use std::fs::{self, File};
 /// Wrappers around the `zip-rs` library to compress and decompress zip
 /// bzip2 archives.
@@ -63,7 +64,7 @@ pub fn compress(src_dir: &Path, dst_file: &File) -> ZipResult<()> {
 }
 
 /// Decompress a source file into the provided destination path.
-pub fn decompress<R>(src_file: R, dest: &Path) -> ZipResult<()>
+pub fn decompress<R>(src_file: R, dest: &Path, skip_subdirs: &HashSet<String>) -> ZipResult<()>
 where
 	R: io::Read + io::Seek,
 {
@@ -73,6 +74,14 @@ where
 		let mut file = archive.by_index(i)?;
 		let file_path = dest.join(file.name());
 
+		let first = file.name().find('/');
+		if let Some(first) = first {
+			let directory = String::from(&file.name()[0..first]);
+			if skip_subdirs.contains(&directory) {
+				continue;
+			}
+		}
+
 		if (&*file.name()).ends_with('/') {
 			fs::create_dir_all(&file_path)?;
 		} else {
diff --git a/util/tests/zip.rs b/util/tests/zip.rs
index 65be45da4d..12f0793315 100644
--- a/util/tests/zip.rs
+++ b/util/tests/zip.rs
@@ -14,6 +14,7 @@
 
 extern crate grin_util as util;
 
+use std::collections::HashSet;
 use std::fs::{self, File};
 use std::io::{self, Write};
 use std::path::Path;
@@ -38,7 +39,8 @@ fn zip_unzip() {
 	fs::create_dir_all(root.join("./dezipped")).unwrap();
 
 	let zip_file = File::open(zip_name).unwrap();
-	zip::decompress(zip_file, &root.join("./dezipped")).unwrap();
+	let skip_subdirs: HashSet<String> = HashSet::new();
+	zip::decompress(zip_file, &root.join("./dezipped"), &skip_subdirs).unwrap();
 
 	assert!(root.join("to_zip/foo.txt").is_file());
 	assert!(root.join("to_zip/bar.txt").is_file());
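A usage sketch for the new skip parameter (`zip_file` and `dest_dir` hypothetical): skipping the "kernel" subdirectory is what lets zip_write leave an already-synced kernel MMR untouched while the rest of the archive is unpacked:

	// Unpack everything except entries under the "kernel" subdirectory.
	let mut skip_subdirs: HashSet<String> = HashSet::new();
	skip_subdirs.insert("kernel".to_string());
	zip::decompress(zip_file, &dest_dir, &skip_subdirs)?;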