diff --git a/packages/rollup-contracts/contracts/state-tree/BinaryUtils.sol b/packages/rollup-contracts/contracts/state-tree/BinaryUtils.sol new file mode 100644 index 0000000000000..aae46d8be5581 --- /dev/null +++ b/packages/rollup-contracts/contracts/state-tree/BinaryUtils.sol @@ -0,0 +1,200 @@ +pragma solidity >=0.5.0 <0.6.0; + +/** + MIT License + Original author: chriseth + */ + +import {D} from "./DataTypes.sol"; + +library Utils { + /// combines two labels into one + function combineLabels(D.Label memory prefix, D.Label memory suffix) internal pure returns (D.Label memory combined) { + combined.length = prefix.length + suffix.length; + combined.data = prefix.data | (suffix.data >> prefix.length); + } + + + /// Returns a label containing the longest common prefix of `check` and `label` + /// and a label consisting of the remaining part of `label`. + function splitCommonPrefix(D.Label memory label, D.Label memory check) internal pure returns (D.Label memory prefix, D.Label memory labelSuffix) { + return splitAt(label, commonPrefixLength(check, label)); + } + /// Splits the label at the given position and returns prefix and suffix, + /// i.e. prefix.length == pos and prefix.data . suffix.data == l.data. + function splitAt(D.Label memory l, uint pos) internal pure returns (D.Label memory prefix, D.Label memory suffix) { + require(pos <= l.length, "Asked to split label at position exceeding the label length."); + require(pos <= 256, "Asked to split label at position exceeding 256 bits."); + prefix.length = pos; + if (pos == 0) { + prefix.data = bytes32(0); + } else { + prefix.data = l.data & ~bytes32((uint(1) << (256 - pos)) - 1); + } + suffix.length = l.length - pos; + suffix.data = l.data << pos; + } + /// Returns the length of the longest common prefix of the two labels. + function commonPrefixLength(D.Label memory a, D.Label memory b) internal pure returns (uint prefix) { + uint length = a.length < b.length ? a.length : b.length; + // TODO: This could actually use a "highestBitSet" helper + uint diff = uint(a.data ^ b.data); + uint mask = 1 << 255; + for (; prefix < length; prefix++) + { + if ((mask & diff) != 0) + break; + diff += diff; + } + } + /// Returns the result of removing a prefix of length `prefix` bits from the + /// given label (i.e. shifting its data to the left). + function removePrefix(D.Label memory l, uint prefix) internal pure returns (D.Label memory r) { + require(prefix <= l.length, "Bad lenght"); + r.length = l.length - prefix; + r.data = l.data << prefix; + } + /// Removes the first bit from a label and returns the bit and a + /// label containing the rest of the label (i.e. shifted to the left). + function chopFirstBit(D.Label memory l) internal pure returns (uint firstBit, D.Label memory tail) { + require(l.length > 0, "Empty element"); + return (uint(l.data >> 255), D.Label(l.data << 1, l.length - 1)); + } + /// Returns the first bit set in the bitfield, where the 0th bit + /// is the least significant. + /// Throws if bitfield is zero. + /// More efficient the smaller the result is. + function lowestBitSet(uint bitfield) internal pure returns (uint bit) { + require(bitfield != 0, "Bad bitfield"); + bytes32 bitfieldBytes = bytes32(bitfield); + // First, find the lowest byte set + uint byteSet = 0; + for (; byteSet < 32; byteSet++) { + if (bitfieldBytes[31 - byteSet] != 0) + break; + } + uint singleByte = uint(uint8(bitfieldBytes[31 - byteSet])); + uint mask = 1; + for (bit = 0; bit < 256; bit ++) { + if ((singleByte & mask) != 0) + return 8 * byteSet + bit; + mask += mask; + } + assert(false); + return 0; + } + /// Returns the value of the `bit`th bit inside `bitfield`, where + /// the least significant is the 0th bit. + function bitSet(uint bitfield, uint bit) internal pure returns (uint) { + return (bitfield & (uint(1) << bit)) != 0 ? 1 : 0; + } +} + + +contract UtilsTest { + function test() public pure { + testLowestBitSet(); + testChopFirstBit(); + testRemovePrefix(); + testCommonPrefix(); + testSplitAt(); + testSplitCommonPrefix(); + } + function testLowestBitSet() internal pure { + require(Utils.lowestBitSet(0x123) == 0, "testLowestBitSet 1"); + require(Utils.lowestBitSet(0x124) == 2, "testLowestBitSet 2"); + require(Utils.lowestBitSet(0x11 << 30) == 30, "testLowestBitSet 3"); + require(Utils.lowestBitSet(1 << 255) == 255, "testLowestBitSet 4"); + } + function testChopFirstBit() internal pure { + D.Label memory l; + l.data = hex"ef1230"; + l.length = 20; + uint bit1; + uint bit2; + uint bit3; + uint bit4; + (bit1, l) = Utils.chopFirstBit(l); + (bit2, l) = Utils.chopFirstBit(l); + (bit3, l) = Utils.chopFirstBit(l); + (bit4, l) = Utils.chopFirstBit(l); + require(bit1 == 1, "testChopFirstBit 1"); + require(bit2 == 1, "testChopFirstBit 2"); + require(bit3 == 1, "testChopFirstBit 3"); + require(bit4 == 0, "testChopFirstBit 4"); + require(l.length == 16, "testChopFirstBit 5"); + require(l.data == hex"F123", "testChopFirstBit 6"); + + l.data = hex"80"; + l.length = 1; + (bit1, l) = Utils.chopFirstBit(l); + require(bit1 == 1, "Fail 7"); + require(l.length == 0, "Fail 8"); + require(l.data == 0, "Fail 9"); + } + function testRemovePrefix() internal pure { + D.Label memory l; + l.data = hex"ef1230"; + l.length = 20; + l = Utils.removePrefix(l, 4); + require(l.length == 16, "testRemovePrefix 1"); + require(l.data == hex"f123", "testRemovePrefix 2"); + l = Utils.removePrefix(l, 15); + require(l.length == 1, "testRemovePrefix 3"); + require(l.data == hex"80", "testRemovePrefix 4"); + l = Utils.removePrefix(l, 1); + require(l.length == 0, "testRemovePrefix 5"); + require(l.data == 0, "testRemovePrefix 6"); + } + function testCommonPrefix() internal pure { + D.Label memory a; + D.Label memory b; + a.data = hex"abcd"; + a.length = 16; + b.data = hex"a000"; + b.length = 16; + require(Utils.commonPrefixLength(a, b) == 4, "testCommonPrefix 1"); + + b.length = 0; + require(Utils.commonPrefixLength(a, b) == 0, "testCommonPrefix 2"); + + b.data = hex"bbcd"; + b.length = 16; + require(Utils.commonPrefixLength(a, b) == 3, "testCommonPrefix 3"); + require(Utils.commonPrefixLength(b, b) == b.length, "testCommonPrefix 4"); + } + function testSplitAt() internal pure { + D.Label memory a; + a.data = hex"abcd"; + a.length = 16; + (D.Label memory x, D.Label memory y) = Utils.splitAt(a, 0); + require(x.length == 0, "testSplitAt 1"); + require(y.length == a.length, "testSplitAt 2"); + require(y.data == a.data, "testSplitAt 3"); + + (x, y) = Utils.splitAt(a, 4); + require(x.length == 4, "testSplitAt 4"); + require(x.data == hex"a0", "testSplitAt 5"); + require(y.length == 12, "testSplitAt 6"); + require(y.data == hex"bcd0", "testSplitAt 7"); + + (x, y) = Utils.splitAt(a, 16); + require(y.length == 0, "testSplitAt 8"); + require(x.length == a.length, "testSplitAt 9"); + require(x.data == a.data, "testSplitAt 10"); + } + function testSplitCommonPrefix() internal pure { + D.Label memory a; + D.Label memory b; + a.data = hex"abcd"; + a.length = 16; + b.data = hex"a0f570"; + b.length = 20; + (D.Label memory prefix, D.Label memory suffix) = Utils.splitCommonPrefix(b, a); + require(prefix.length == 4, "testSplitCommonPrefix 1"); + require(prefix.data == hex"a0", "testSplitCommonPrefix 2"); + require(suffix.length == 16, "testSplitCommonPrefix 3"); + require(suffix.data == hex"0f57", "testSplitCommonPrefix 4"); + } +} + diff --git a/packages/rollup-contracts/contracts/state-tree/DataTypes.sol b/packages/rollup-contracts/contracts/state-tree/DataTypes.sol new file mode 100644 index 0000000000000..57001fdea132d --- /dev/null +++ b/packages/rollup-contracts/contracts/state-tree/DataTypes.sol @@ -0,0 +1,22 @@ +pragma solidity >=0.5.0 <0.6.0; + +/** + MIT License + Copyright (c) 2017 chriseth + */ + +library D { + struct Label { + bytes32 data; + uint length; + } + + struct Edge { + bytes32 node; + Label label; + } + + struct Node { + Edge[2] children; + } +} diff --git a/packages/rollup-contracts/contracts/state-tree/FullPatriciaTree.sol b/packages/rollup-contracts/contracts/state-tree/FullPatriciaTree.sol new file mode 100644 index 0000000000000..d8d0187fadbf1 --- /dev/null +++ b/packages/rollup-contracts/contracts/state-tree/FullPatriciaTree.sol @@ -0,0 +1,387 @@ +pragma solidity >=0.5.0 <0.6.0; +pragma experimental ABIEncoderV2; + +import {D} from "./DataTypes.sol"; +import {Utils} from "./BinaryUtils.sol"; + +/** + MIT License + Original author: chriseth + Rewritten by: Wanseob Lim + */ + +library FullPatriciaTree { + struct Tree { + // Mapping of hash of key to value + mapping(bytes32 => bytes) values; + + // Particia tree nodes (hash to decoded contents) + mapping(bytes32 => D.Node) nodes; + // The current root hash, keccak256(node(path_M('')), path_M('')) + bytes32 root; + D.Edge rootEdge; + } + + function get(Tree storage tree, bytes32 key) internal view returns (bytes memory) { + return getValue(tree, _findNode(tree, key)); + } + + function safeGet(Tree storage tree, bytes32 key) internal view returns (bytes memory value) { + bytes32 valueHash = _findNode(tree, key); + require(valueHash != bytes32(0)); + value = getValue(tree, valueHash); + require(valueHash == keccak256(value)); + } + + function doesInclude(Tree storage tree, bytes memory key) internal view returns (bool) { + return doesIncludeHashedKey(tree, keccak256(key)); + } + + function doesIncludeHashedKey(Tree storage tree, bytes32 hashedKey) internal view returns (bool) { + bytes32 valueHash = _findNode(tree, hashedKey); + return (valueHash != bytes32(0)); + } + + function getValue(Tree storage tree, bytes32 valueHash) internal view returns (bytes memory) { + return tree.values[valueHash]; + } + + function getRootHash(Tree storage tree) internal view returns (bytes32) { + return tree.root; + } + + + function getNode(Tree storage tree, bytes32 hash) internal view returns (uint, bytes32, bytes32, uint, bytes32, bytes32) { + D.Node storage n = tree.nodes[hash]; + return ( + n.children[0].label.length, n.children[0].label.data, n.children[0].node, + n.children[1].label.length, n.children[1].label.data, n.children[1].node + ); + } + + function getRootEdge(Tree storage tree) internal view returns (uint, bytes32, bytes32) { + return (tree.rootEdge.label.length, tree.rootEdge.label.data, tree.rootEdge.node); + } + + function edgeHash(D.Edge memory e) internal pure returns (bytes32) { + return keccak256(abi.encode(e.node, e.label.length, e.label.data)); + } + + // Returns the hash of the encoding of a node. + function hash(D.Node memory n) internal pure returns (bytes32) { + return keccak256(abi.encode(edgeHash(n.children[0]), edgeHash(n.children[1]))); + } + + // Returns the Merkle-proof for the given key + // Proof format should be: + // - uint branchMask - bitmask with high bits at the positions in the key + // where we have branch nodes (bit in key denotes direction) + // - bytes32[] hashes - hashes of sibling edges + + function getProof(Tree storage tree, bytes32 key) public view returns (uint branchMask, bytes32[] memory _siblings) { + // We will progressively "eat" into the key from the left as we traverse. + D.Label memory remaining; + // We initialize to the full key + remaining = D.Label(key, 256); + // Keeps track of how much we have "eaten" into the key. + // It should always hold that bitsTraversed + remaining.length == keyLength (256) at the end of each loop iteration + uint bitsTraversed = 0; + // Current edge the traversal is processing. + // Each loop iteration will chose the right or left child as the new current edge. + D.Edge memory currentEdge; + // Start traversal at the root + currentEdge = tree.rootEdge; + + // Proof to return along with branch bitmask + bytes32[256] memory siblings; + uint numSiblings = 0; + while (true) { + // Figure out the common prefix between the current edge and remaning bits in the traversal. + // If the requested key has indeed been set, the current edge should be a prefix of the remaining. + D.Label memory prefix; + D.Label memory suffix; + (prefix, suffix) = Utils.splitCommonPrefix(remaining, currentEdge.label); + require( + prefix.length == currentEdge.label.length, + 'Reached an edge in traversal whose label is not a strict prefix of the remaining part of key. This indicates that the requested key has not been set.' + ); + if (suffix.length == 0) { + // Found a match! + break; + } + // Now that we are traversing this edge, add its length to bitsTraversed. + bitsTraversed += prefix.length; + // The next bit in the key determines whether to branch left or right. + // So, this sets the bitsTraversed'th bit in the branch mask to 1. + branchMask |= uint(1) << (255 - bitsTraversed); + + // As explained in the last line, we traverse left or right based on the next bit. + uint head; + D.Label memory tail; + (head, tail) = Utils.chopFirstBit(suffix); + // head, either 0 or 1, tells us which child edge is to traverse to, and which is the sibling to supply in our proof. + uint siblingIndex = 1 - head; + siblings[numSiblings++] = edgeHash( + tree.nodes[currentEdge.node].children[siblingIndex] + ); + // Now, update the current edge to be processed in next iteration + currentEdge = tree.nodes[currentEdge.node].children[head]; + // Account for having processed another bit by choosing left or right. + bitsTraversed += 1; + remaining = tail; + } + if (numSiblings > 0) + { + _siblings = new bytes32[](numSiblings); + for (uint i = 0; i < numSiblings; i++) + _siblings[i] = siblings[i]; + } + } + /** + * @notice Gets a non inclusion proof for the given key. + * A non inclusion proof is an inclusion proof of a node which would be split if the key were in the tree, but is not split. + * Throws if the given key is actually included. + * @param tree The tree to get an inclusion proof for (handled automatically by library, see FullPatriciaTreeImplementation) + * @param key The key to get an inclusion proof for. + * @return conflictingNodeKeyAsLabel - The key for the conflicting node (i.e. the full prefix that all the conflicting nodes' ancestor keys share) expressed as a label + * @return potentialSiblingValue - The hash for the conflicting node + * @return branchMask - A bitmask containing 1 wherever the key is split (branches into two children) + * @return _siblings - The siblings at each split in the branchmask used to verify inclusion of the conflicting node. + */ + function getNonInclusionProof(Tree storage tree, bytes32 key) internal view returns ( + D.Label memory conflictingNodeKeyAsLabel, + bytes32 conflictingNodeValue, + uint branchMask, + bytes32[] memory _siblings + ){ + uint length; + uint numSiblings; + + D.Label memory cumulativeKeyLabel = D.Label(key, 256); + + // Start from root edge + D.Label memory remainingLabel = cumulativeKeyLabel; + D.Edge memory currentEdge = tree.rootEdge; + bytes32[256] memory siblings; + + while (true) { + D.Label memory prefix; + D.Label memory suffix; + (prefix, suffix) = Utils.splitCommonPrefix(remainingLabel, currentEdge.label); + + // suffix.length == 0 means that the key exists. Thus the length of the suffix should be not zero + require(suffix.length != 0, 'Requested non-inclusion proof, the given key is included'); + + if (prefix.length >= currentEdge.label.length) { + // Partial matched, keep finding + length += prefix.length; + branchMask |= uint(1) << (255 - length); + length += 1; + uint head; + (head, remainingLabel) = Utils.chopFirstBit(suffix); + siblings[numSiblings++] = edgeHash(tree.nodes[currentEdge.node].children[1 - head]); + currentEdge = tree.nodes[currentEdge.node].children[head]; + } else { + // Found the potential sibling. Set data to return + (D.Label memory parentKeyAsLabel, ) = Utils.splitAt(cumulativeKeyLabel, length); + conflictingNodeKeyAsLabel = Utils.combineLabels(parentKeyAsLabel, currentEdge.label); + conflictingNodeValue = currentEdge.node; + break; + } + } + if (numSiblings > 0) + { + _siblings = new bytes32[](numSiblings); + for (uint i = 0; i < numSiblings; i++) + _siblings[i] = siblings[i]; + } + } + + // TODO comment/explain these args + function verifyEdgeInclusionProof( + bytes32 rootHash, + bytes32 edgeCommittment, + D.Label memory fullEdgeLabel, + uint branchMask, + bytes32[] memory siblings + ) internal pure { + // We will progressively "eat" into the label from the right as we hash up to the root. + D.Label memory remaining = fullEdgeLabel; + // We will progressively hash the current edge up with its siblings until we get the root edge. + D.Edge memory currentEdge; + // To start, this is the edge we are verifying so it's the edgeHash which was input + currentEdge.node = edgeCommittment; + // Iterate over each set bit in the branch mask to build parent edges, starting from the right. + for (uint i = 0; branchMask != 0; i++) { + // Find the lowest index nonzero bit in the mask, where rightmost == index 0 + uint bitSet = Utils.lowestBitSet(branchMask); + // Remove from bitmask as we are about to process it + branchMask &= ~(uint(1) << bitSet); + // The label for the current edge is the suffix of the remaining label proceeeding the set bit + (remaining, currentEdge.label) = Utils.splitAt(remaining, 255 - bitSet); // (255 - bitSet) since bitset indexes from the right + // The bitSet'th bit in the key determines whether the sibling is left or right. + uint bit; + // chop this bit off the label, it is implicit in the merkle path so will not be included in a label + (bit, currentEdge.label) = Utils.chopFirstBit(currentEdge.label); + bytes32[2] memory edgeHashes; + edgeHashes[bit] = edgeHash(currentEdge); + edgeHashes[1 - bit] = siblings[siblings.length - i - 1]; + currentEdge.node = keccak256(abi.encode(edgeHashes[0], edgeHashes[1])); + } + // no more branching, so the remaining label is the root edge's label + currentEdge.label = remaining; + require(rootHash == edgeHash(currentEdge), 'Edge inclusion proof verification failed: root hashes do not match.'); + } + + function verifyProof( + bytes32 rootHash, + bytes32 key, + bytes memory value, + uint branchMask, + bytes32[] memory siblings + ) public pure { + // The edge above a leaf commits to the leaf value (i.e. what was actually set) + bytes32 edgeCommittment = keccak256(value); + // The full "label" for a leaf node is the entirety of the key. + D.Label memory fullLabel = D.Label(key, 256); + + verifyEdgeInclusionProof( + rootHash, + edgeCommittment, + fullLabel, + branchMask, + siblings + ); + } + + function verifyNonInclusionProof( + bytes32 rootHash, + bytes32 key, + bytes32 conflictingEdgeKeyData, + uint conflictingEdgeKeyLength, + bytes32 conflictingEdgeCommitment, + uint branchMask, + bytes32[] memory siblings + ) public pure { + // first, verify there is a conflict between the key and given edge + require(conflictingEdgeKeyLength <= 256, 'invalid label specified--exceeds tree depth.'); + D.Label memory conflictingEdgeKeyAsLabel = D.Label(conflictingEdgeKeyData, conflictingEdgeKeyLength); + D.Label memory fullLeafLabel = D.Label(key, 256); + uint indexOfConflict = Utils.commonPrefixLength(conflictingEdgeKeyAsLabel, fullLeafLabel); + bool areConflicting = branchMask & (1 << 255 - indexOfConflict) == 0; + + require(areConflicting, 'The provided conflicting edge is not actually conflicting.'); + verifyEdgeInclusionProof( + rootHash, + conflictingEdgeCommitment, + conflictingEdgeKeyAsLabel, + branchMask, + siblings + ); + } + + function insert(Tree storage tree, bytes32 key, bytes memory value) internal { + D.Label memory k = D.Label(key, 256); + bytes32 valueHash = keccak256(value); + tree.values[valueHash] = value; + D.Edge memory e; + if (tree.rootEdge.node == 0 && tree.rootEdge.label.length == 0) + { + // Empty Trie + e.label = k; + e.node = valueHash; + } + else + { + e = _insertAtEdge(tree, tree.rootEdge, k, valueHash); + } + tree.root = edgeHash(e); + tree.rootEdge = e; + } + + function _insertAtNode( + Tree storage tree, + bytes32 nodeHash, + D.Label memory key, + bytes32 value + ) private returns (bytes32) { + require(key.length > 1, "Bad key"); + D.Node memory n = tree.nodes[nodeHash]; + (uint256 head, D.Label memory tail) = Utils.chopFirstBit(key); + n.children[head] = _insertAtEdge(tree, n.children[head], tail, value); + return _replaceNode(tree, nodeHash, n); + } + + function _insertAtEdge( + Tree storage tree, + D.Edge memory e, + D.Label memory key, bytes32 value + ) private returns (D.Edge memory) { + require(key.length >= e.label.length, "Key lenght mismatch label lenght"); + (D.Label memory prefix, D.Label memory suffix) = Utils.splitCommonPrefix(key, e.label); + bytes32 newNodeHash; + if (suffix.length == 0) { + // Full match with the key, update operation + newNodeHash = value; + } else if (prefix.length >= e.label.length) { + // Partial match, just follow the path + newNodeHash = _insertAtNode(tree, e.node, suffix, value); + } else { + // Mismatch, so let us create a new branch node. + (uint256 head, D.Label memory tail) = Utils.chopFirstBit(suffix); + D.Node memory branchNode; + branchNode.children[head] = D.Edge(value, tail); + branchNode.children[1 - head] = D.Edge(e.node, Utils.removePrefix(e.label, prefix.length + 1)); + newNodeHash = _insertNode(tree, branchNode); + } + return D.Edge(newNodeHash, prefix); + } + + function _insertNode(Tree storage tree, D.Node memory n) private returns (bytes32 newHash) { + bytes32 h = hash(n); + tree.nodes[h].children[0] = n.children[0]; + tree.nodes[h].children[1] = n.children[1]; + return h; + } + + function _replaceNode( + Tree storage tree, + bytes32 oldHash, + D.Node memory n + ) private returns (bytes32 newHash) { + delete tree.nodes[oldHash]; + return _insertNode(tree, n); + } + + function _findNode(Tree storage tree, bytes32 key) private view returns (bytes32) { + if (tree.rootEdge.node == 0 && tree.rootEdge.label.length == 0) { + return 0; + } else { + D.Label memory k = D.Label(key, 256); + return _findAtEdge(tree, tree.rootEdge, k); + } + } + + function _findAtNode(Tree storage tree, bytes32 nodeHash, D.Label memory key) private view returns (bytes32) { + require(key.length > 1); + D.Node memory n = tree.nodes[nodeHash]; + (uint head, D.Label memory tail) = Utils.chopFirstBit(key); + return _findAtEdge(tree, n.children[head], tail); + } + + function _findAtEdge(Tree storage tree, D.Edge memory e, D.Label memory key) private view returns (bytes32){ + require(key.length >= e.label.length); + (D.Label memory prefix, D.Label memory suffix) = Utils.splitCommonPrefix(key, e.label); + if (suffix.length == 0) { + // Full match with the key, update operation + return e.node; + } else if (prefix.length >= e.label.length) { + // Partial match, just follow the path + return _findAtNode(tree, e.node, suffix); + } else { + // Mismatch, return empty bytes + return bytes32(0); + } + } +} + diff --git a/packages/rollup-contracts/contracts/state-tree/FullPatriciaTreeImplementation.sol b/packages/rollup-contracts/contracts/state-tree/FullPatriciaTreeImplementation.sol new file mode 100644 index 0000000000000..88f42febb9a85 --- /dev/null +++ b/packages/rollup-contracts/contracts/state-tree/FullPatriciaTreeImplementation.sol @@ -0,0 +1,76 @@ +pragma solidity >=0.5.0 <0.6.0; +pragma experimental ABIEncoderV2; + +import {D} from "./DataTypes.sol"; +import {FullPatriciaTree} from "./FullPatriciaTree.sol"; + +contract FullPatriciaTreeImplementation { + using FullPatriciaTree for FullPatriciaTree.Tree; + FullPatriciaTree.Tree tree; + + constructor () public { + } + + function insert(bytes32 key, bytes memory value) public { + tree.insert(key, value); + } + + function get(bytes32 key) public view returns (bytes memory) { + return tree.get(key); + } + + function safeGet(bytes32 key) public view returns (bytes memory) { + return tree.safeGet(key); + } + + function doesInclude(bytes memory key) public view returns (bool) { + return tree.doesInclude(key); + } + + function getValue(bytes32 hash) public view returns (bytes memory) { + return tree.values[hash]; + } + + function getRootHash() public view returns (bytes32) { + return tree.getRootHash(); + } + + function getNode(bytes32 hash) public view returns (uint, bytes32, bytes32, uint, bytes32, bytes32) { + return tree.getNode(hash); + } + + function getRootEdge() public view returns (uint, bytes32, bytes32) { + return tree.getRootEdge(); + } + + function getProof(bytes32 key) public view returns (uint branchMask, bytes32[] memory _siblings) { + return tree.getProof(key); + } + + // todo naming -- these arent always leaves + function getNonInclusionProof(bytes32 key) public view returns ( + D.Label memory potentialSiblingCumulativeLabel, + bytes32 potentialSiblingValue, + uint branchMask, + bytes32[] memory _siblings + ) { + return tree.getNonInclusionProof(key); + } + + function verifyProof(bytes32 rootHash, bytes32 key, bytes memory value, uint branchMask, bytes32[] memory siblings) public pure { + FullPatriciaTree.verifyProof(rootHash, key, value, branchMask, siblings); + } + + function verifyNonInclusionProof( + bytes32 rootHash, + bytes32 key, + bytes32 conflictingEdgeFullLabelData, + uint conflictingEdgeFullLabelLength, + bytes32 conflictingEdgeCommitment, + uint branchMask, + bytes32[] memory siblings + ) public pure { + FullPatriciaTree.verifyNonInclusionProof(rootHash, key, conflictingEdgeFullLabelData, conflictingEdgeFullLabelLength, conflictingEdgeCommitment, branchMask, siblings); + } + +} \ No newline at end of file diff --git a/packages/rollup-contracts/test/state-tree/FullPatriciaTree.spec.ts b/packages/rollup-contracts/test/state-tree/FullPatriciaTree.spec.ts new file mode 100644 index 0000000000000..5006c3289d5aa --- /dev/null +++ b/packages/rollup-contracts/test/state-tree/FullPatriciaTree.spec.ts @@ -0,0 +1,213 @@ +import '../setup' + +/* External Imports */ +import { + getLogger, + numberToHexString, + padToLength, + TestUtils, +} from '@eth-optimism/core-utils' +import { + createMockProvider, + deployContract, + getWallets, + link, +} from 'ethereum-waffle' + +/* Logging */ +const log = getLogger('patricia-tree', true) + +/* Contract Imports */ +import * as FullPatriciaTreeImplementation from '../../build/FullPatriciaTreeImplementation.json' +import * as FullPatriciaTreeLibrary from '../../build/FullPatriciaTree.json' +import * as UtilsTest from '../../build/UtilsTest.json' + +const insertSequentialKeys = async ( + treeContract: any, + numKeysToInsert: number, + startingIndex: number = 0 +): Promise> => { + const pairs = [] + for (let i = startingIndex; i < startingIndex + numKeysToInsert; i++) { + const key = padToLength(numberToHexString(i), 32 * 2) + const value = padToLength(numberToHexString(i * 32), 32 * 2) + await treeContract.insert(key, value) + pairs.push({ + key, + value, + }) + } + return pairs +} + +const insertAndVerifySequential = async ( + treeContract: any, + numKeysToInsert: number, + startingIndex: number = 0 +) => { + const KVPairs = await insertSequentialKeys( + treeContract, + numKeysToInsert, + startingIndex + ) + const rootHash = await treeContract.getRootHash() + for (const pair of KVPairs) { + const proof = await treeContract.getProof(pair.key) + await treeContract.verifyProof( + rootHash, + pair.key, + pair.value, + proof.branchMask, + proof._siblings + ) + } +} + +const getAndVerifyNonInclusionProof = async ( + treeContract: any, + key: number +) => { + const keyToUse = padToLength(numberToHexString(key), 32 * 2) + const nonInclusionProof = await treeContract.getNonInclusionProof(keyToUse) + + const conflictingEdgeLabel = nonInclusionProof[0] + const leafNode = nonInclusionProof[1] + const branchMask = nonInclusionProof[2] + const siblings = nonInclusionProof[3] + + const rootHash = await treeContract.getRootHash() + + await treeContract.verifyNonInclusionProof( + rootHash, + keyToUse, + conflictingEdgeLabel[0], + conflictingEdgeLabel[1], + leafNode, + branchMask, + siblings + ) +} + +describe('PatriciaTree (full, non-stateless version)', async () => { + let fullTree + const provider = createMockProvider() + const [wallet1, wallet2] = getWallets(provider) + + before(async () => { + const treeLibrary = await deployContract( + wallet1, + FullPatriciaTreeLibrary, + [] + ) + link( + FullPatriciaTreeImplementation, + 'contracts/state-tree/FullPatriciaTree.sol:FullPatriciaTree', + treeLibrary.address + ) + }) + + beforeEach('Deploy new PatriciaTree', async () => { + fullTree = await deployContract( + wallet1, + FullPatriciaTreeImplementation, + [], + { + gasLimit: 6700000, + } + ) + }) + + describe('Works as a keystore', async () => { + const FOO = + '0x0000000000000000000000000000000000000000000000067320000000000000' + const BAR = + '0x0000000000000000000000000000000004578000000000000000000000000000' + const FUZ = + '0x0000000000000000157800000000000000000000000000000000000000000000' + describe('get()', async () => { + it('should return stored value for the given key', async () => { + await fullTree.insert(FOO, BAR) + const retrieved = await fullTree.get(FOO) + retrieved.should.equal(BAR) + }) + }) + + describe('safeGet()', async () => { + it('should return stored value for the given key', async () => { + await fullTree.insert(FOO, BAR) + const retrieved = await fullTree.safeGet(FOO) + retrieved.should.equal(BAR) + }) + it('should throw if the given key is not included', async () => { + await fullTree.insert(FOO, BAR) + TestUtils.assertThrowsAsync(async () => { + await fullTree.safeGet(FUZ) + }) + }) + }) + }) + + describe('Inclusion proof generation and verification', async () => { + it('should work for the single-key case', async () => { + const key = 150 + const pairs = await insertAndVerifySequential(fullTree, 1, key) + }) + it('should work for the two-key sequential case', async () => { + const startKey = 150 + const pairs = await insertAndVerifySequential(fullTree, 2, startKey) + }) + it('should work for 17-key sequential case', async () => { + const startKey = 150 + const pairs = await insertAndVerifySequential(fullTree, 17, startKey) + }) + it('should work for multiple non-sequential keys', async () => { + const keyToVerify = 18 + await insertSequentialKeys(fullTree, 1, 5) + await insertSequentialKeys(fullTree, 1, 13) + await insertSequentialKeys(fullTree, 1, 27) + await insertSequentialKeys(fullTree, 1, 100000) + await insertSequentialKeys(fullTree, 1, 3000000345) + const pairs = await insertSequentialKeys(fullTree, 1, keyToVerify) + const pair = pairs[0] + const rootHash = await fullTree.getRootHash() + const proof = await fullTree.getProof(pair.key) + await fullTree.verifyProof( + rootHash, + pair.key, + pair.value, + proof.branchMask, + proof._siblings + ) + }) + }) + describe('Non-inclusion proof generation and verification', async () => { + it('Should work for an unset key next to a set one', async () => { + await insertSequentialKeys(fullTree, 1, 0) + await getAndVerifyNonInclusionProof(fullTree, 1) + }) + it('Should work for an unset key between two set ones', async () => { + await insertSequentialKeys(fullTree, 1, 0) + await insertSequentialKeys(fullTree, 1, 2) + await getAndVerifyNonInclusionProof(fullTree, 1) + }) + it('Should work for an unset key next to some set ones', async () => { + await insertSequentialKeys(fullTree, 7, 1) + await getAndVerifyNonInclusionProof(fullTree, 0) + }) + it('Should work for an unset key far away from some set ones', async () => { + await insertSequentialKeys(fullTree, 3, 0) + await insertSequentialKeys(fullTree, 3, 60) + await getAndVerifyNonInclusionProof(fullTree, 17) + }) + }) + + describe('Binary Utils library', async () => { + it('Legacy tests written in Solidity should all pass', async () => { + const testerContract = await deployContract(wallet2, UtilsTest) + await testerContract.test() + }) + }) +})