diff --git a/src/contracts/interfaces/ICrossChainRegistry.sol b/src/contracts/interfaces/ICrossChainRegistry.sol index 59f6a385ec..7434163351 100644 --- a/src/contracts/interfaces/ICrossChainRegistry.sol +++ b/src/contracts/interfaces/ICrossChainRegistry.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: BUSL-1.1 -pragma solidity >=0.5.0; +pragma solidity ^0.8.27; import {OperatorSet} from "../libraries/OperatorSetLib.sol"; import "./IOperatorTableCalculator.sol"; @@ -7,6 +7,42 @@ import "./IOperatorTableCalculator.sol"; interface ICrossChainRegistryErrors { /// @notice Thrown when the chainId is invalid error InvalidChainId(); + + /// @notice Thrown when a generation reservation already exists for the operator set + error GenerationReservationAlreadyExists(); + + /// @notice Thrown when a generation reservation does not exist for the operator set + error GenerationReservationDoesNotExist(); + + /// @notice Thrown when the operator table calculator address is invalid + error InvalidOperatorTableCalculator(); + + /// @notice Thrown when a transport destination is already added for the operator set + error TransportDestinationAlreadyAdded(); + + /// @notice Thrown when a transport destination is not found for the operator set + error TransportDestinationNotFound(); + + /// @notice Thrown when a chain ID is already whitelisted + error ChainIDAlreadyWhitelisted(); + + /// @notice Thrown when a chain ID is not whitelisted + error ChainIDNotWhitelisted(); + + /// @notice Thrown when the staleness period is zero + error StalenessPeriodZero(); + + /// @notice Thrown when the operator set is not valid + error InvalidOperatorSet(); + + /// @notice Thrown when the chainIDs array is empty + error EmptyChainIDsArray(); + + /// @notice Thrown when a at least one transport destination is required + error RequireAtLeastOneTransportDestination(); + + /// @notice Thrown when the storage is not cleared + error NeedToDelete(); } interface ICrossChainRegistryTypes { @@ -21,34 +57,47 @@ interface ICrossChainRegistryTypes { } } -interface ICrossChainRegistryEvents { - /// @notice Emitted when a generation reservation is made - event GenerationReservationMade(OperatorSet operatorSet, IOperatorTableCalculator operatorTableCalculator); +interface ICrossChainRegistryEvents is ICrossChainRegistryTypes { + /// @notice Emitted when a generation reservation is created + event GenerationReservationCreated(OperatorSet operatorSet); /// @notice Emitted when a generation reservation is removed - event GenerationReservationRemoved(OperatorSet operatorSet, IOperatorTableCalculator operatorTableCalculator); + event GenerationReservationRemoved(OperatorSet operatorSet); + + /// @notice Emitted when an operatorTableCalculator is set + event OperatorTableCalculatorSet(OperatorSet operatorSet, IOperatorTableCalculator operatorTableCalculator); + + /// @notice Emitted when an operatorSetConfig is set + event OperatorSetConfigSet(OperatorSet operatorSet, OperatorSetConfig config); /// @notice Emitted when a transport destination is added - event TransportDestinationAdded(OperatorSet operatorSet, uint32 chainID); + event TransportDestinationAdded(OperatorSet operatorSet, uint256 chainID); /// @notice Emitted when a transport destination is removed - event TransportDestinationRemoved(OperatorSet operatorSet, uint32 chainID); + event TransportDestinationRemoved(OperatorSet operatorSet, uint256 chainID); /// @notice Emitted when a chainID is added to the whitelist - event ChainIDAddedToWhitelist(uint32 chainID); + event ChainIDAddedToWhitelist(uint256 chainID); /// @notice Emitted when a chainID is removed from the whitelist - event ChainIDRemovedFromWhitelist(uint32 chainID); + event ChainIDRemovedFromWhitelist(uint256 chainID); } -interface ICrossChainRegistry is ICrossChainRegistryErrors, ICrossChainRegistryTypes, ICrossChainRegistryEvents { +interface ICrossChainRegistry is ICrossChainRegistryErrors, ICrossChainRegistryEvents { /** - * @notice Initiates a generation reservation + * @notice Creates a generation reservation * @param operatorSet the operatorSet to make a reservation for * @param operatorTableCalculator the address of the operatorTableCalculator + * @param config the config to set for the operatorSet + * @param chainIDs the chainIDs to add as transport destinations * @dev msg.sender must be UAM permissioned for operatorSet.avs */ - function requestGenerationReservation(OperatorSet calldata operatorSet, address operatorTableCalculator) external; + function createGenerationReservation( + OperatorSet calldata operatorSet, + IOperatorTableCalculator operatorTableCalculator, + OperatorSetConfig calldata config, + uint256[] calldata chainIDs + ) external; /** * @notice Removes a generation reservation for a given operatorSet @@ -60,54 +109,73 @@ interface ICrossChainRegistry is ICrossChainRegistryErrors, ICrossChainRegistryT ) external; /** - * @notice Adds a destination chain to transport to - * @param chainID to add transport to + * @notice Sets the operatorTableCalculator for the operatorSet + * @param operatorSet the operatorSet whose operatorTableCalculator is desired to be set + * @param operatorTableCalculator the contract to call to calculate the operator table * @dev msg.sender must be UAM permissioned for operatorSet.avs + * @dev operatorSet must have an active reservation */ - function addTransportDestination(OperatorSet calldata operatorSet, uint32 chainID) external; + function setOperatorTableCalculator( + OperatorSet calldata operatorSet, + IOperatorTableCalculator operatorTableCalculator + ) external; /** - * @notice Removes a destination chain to transport to - * @param chainID to remove transport to + * @notice Sets the operatorSetConfig for a given operatorSet + * @param operatorSet the operatorSet to set the operatorSetConfig for + * @param config the config to set * @dev msg.sender must be UAM permissioned for operatorSet.avs + * @dev operatorSet must have an active generation reservation */ - function removeTransportDestination(OperatorSet calldata operatorSet, uint32 chainID) external; + function setOperatorSetConfig(OperatorSet calldata operatorSet, OperatorSetConfig calldata config) external; /** - * @notice Sets the operatorTableCalculator for the operatorSet - * @param operatorSet the operatorSet whose operatorTableCalculator is desired to be set - * @param calculator the contract to call to calculate the operator table + * @notice Adds destination chains to transport to + * @param operatorSet the operatorSet to add transport destinations for + * @param chainIDs to add transport to * @dev msg.sender must be UAM permissioned for operatorSet.avs - * @dev operatorSet must have an active reservation + * @dev Will create a transport reservation if one doesn't exist */ - function setOperatorTableCalculator( - OperatorSet calldata operatorSet, - IOperatorTableCalculator calculator - ) external; + function addTransportDestinations(OperatorSet calldata operatorSet, uint256[] calldata chainIDs) external; /** - * @notice Adds a chainID to the whitelist of chainIDs that can be transported to - * @param chainID the chainID to add to the whitelist + * @notice Removes destination chains to transport to + * @param operatorSet the operatorSet to remove transport destinations for + * @param chainIDs to remove transport to + * @dev msg.sender must be UAM permissioned for operatorSet.avs + * @dev Will remove the transport reservation if all destinations are removed + */ + function removeTransportDestinations(OperatorSet calldata operatorSet, uint256[] calldata chainIDs) external; + + /** + * @notice Adds chainIDs to the whitelist of chainIDs that can be transported to + * @param chainIDs the chainIDs to add to the whitelist * @dev msg.sender must be the owner of the CrossChainRegistry */ - function addChainIDToWhitelist( - uint32 chainID + function addChainIDsToWhitelist( + uint256[] calldata chainIDs ) external; /** - * @notice Removes a chainID from the whitelist of chainIDs that can be transported to - * @param chainID the chainID to remove from the whitelist + * @notice Removes chainIDs from the whitelist of chainIDs that can be transported to + * @param chainIDs the chainIDs to remove from the whitelist * @dev msg.sender must be the owner of the CrossChainRegistry */ - function removeChainIDFromWhitelist( - uint32 chainID + function removeChainIDsFromWhitelist( + uint256[] calldata chainIDs ) external; /** - * @notice Gets the list of chains that are supported by the CrossChainRegistry - * @return An array of chainIDs that are supported by the CrossChainRegistry + * + * VIEW FUNCTIONS + * + */ + + /** + * @notice Gets the active generation reservations + * @return An array of operatorSets with active generationReservations */ - function getSupportedChains() external view returns (uint32[] memory); + function getActiveGenerationReservations() external view returns (OperatorSet[] memory); /** * @notice Gets the operatorTableCalculator for a given operatorSet @@ -115,25 +183,51 @@ interface ICrossChainRegistry is ICrossChainRegistryErrors, ICrossChainRegistryT * @return The operatorTableCalculator for the given operatorSet */ function getOperatorTableCalculator( - OperatorSet calldata operatorSet + OperatorSet memory operatorSet ) external view returns (IOperatorTableCalculator); /** - * @notice Gets the active generation reservations - * @return An array of operatorSets with active generationReservations - * @return An array of the corresponding operatorTableCalculators + * @notice Gets the operatorSetConfig for a given operatorSet + * @param operatorSet the operatorSet to get the operatorSetConfig for + * @return The operatorSetConfig for the given operatorSet */ - function getActiveGenerationReservations() - external - view - returns (OperatorSet[] memory, IOperatorTableCalculator[] memory); + function getOperatorSetConfig( + OperatorSet memory operatorSet + ) external view returns (OperatorSetConfig memory); + + /** + * @notice Calculates the operatorTableBytes for a given operatorSet + * @param operatorSet the operatorSet to calculate the operator table for + * @return the encoded operatorTableBytes containing: + * - operatorSet details + * - curve type from KeyRegistrar + * - operator set configuration + * - calculated operator table from the calculator contract + * @dev This function aggregates data from multiple sources for cross-chain transport + */ + function calculateOperatorTableBytes( + OperatorSet calldata operatorSet + ) external view returns (bytes memory); + + /** + * @notice Gets the active transport reservations + * @return An array of operatorSets with active transport reservations + * @return An array of chainIDs that the operatorSet is configured to transport to + */ + function getActiveTransportReservations() external view returns (OperatorSet[] memory, uint256[][] memory); /** * @notice Gets the transport destinations for a given operatorSet * @param operatorSet the operatorSet to get the transport destinations for - * @return An array of chainIDs that are transport destinations for the given operatorSet + * @return An array of chainIDs that the operatorSet is configured to transport to */ function getTransportDestinations( - OperatorSet calldata operatorSet - ) external view returns (uint32[] memory); + OperatorSet memory operatorSet + ) external view returns (uint256[] memory); + + /** + * @notice Gets the list of chains that are supported by the CrossChainRegistry + * @return An array of chainIDs that are supported by the CrossChainRegistry + */ + function getSupportedChains() external view returns (uint256[] memory); } diff --git a/src/contracts/multichain/CrossChainRegistry.sol b/src/contracts/multichain/CrossChainRegistry.sol new file mode 100644 index 0000000000..45d0ef3cdc --- /dev/null +++ b/src/contracts/multichain/CrossChainRegistry.sol @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.27; + +import "@openzeppelin-upgrades/contracts/proxy/utils/Initializable.sol"; +import "@openzeppelin-upgrades/contracts/access/OwnableUpgradeable.sol"; +import "../mixins/PermissionControllerMixin.sol"; +import "../mixins/SemVerMixin.sol"; +import "../permissions/Pausable.sol"; +import "../interfaces/IKeyRegistrar.sol"; +import "./CrossChainRegistryStorage.sol"; + +/** + * @title CrossChainRegistry + * @author Layr Labs, Inc. + * @notice Implementation contract for managing cross-chain operator set configurations and generation reservations + * @dev Manages operator table calculations, transport destinations, and operator set configurations for cross-chain operations + */ +contract CrossChainRegistry is + Initializable, + OwnableUpgradeable, + Pausable, + CrossChainRegistryStorage, + PermissionControllerMixin, + SemVerMixin +{ + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.UintSet; + using OperatorSetLib for OperatorSet; + + /** + * + * MODIFIERS + * + */ + + /** + * @dev Validates that the operator set exists in the AllocationManager + * @param operatorSet The operator set to validate + */ + modifier isValidOperatorSet( + OperatorSet calldata operatorSet + ) { + require(allocationManager.isOperatorSet(operatorSet), InvalidOperatorSet()); + _; + } + + /** + * + * INITIALIZING FUNCTIONS + * + */ + + /** + * @dev Initializes the CrossChainRegistry with immutable dependencies + * @param _allocationManager The allocation manager for operator set validation + * @param _keyRegistrar The key registrar for operator set curve type validation + * @param _permissionController The permission controller for access control + * @param _pauserRegistry The pauser registry for pause functionality + * @param _version The semantic version of the contract + */ + constructor( + IAllocationManager _allocationManager, + IKeyRegistrar _keyRegistrar, + IPermissionController _permissionController, + IPauserRegistry _pauserRegistry, + string memory _version + ) + CrossChainRegistryStorage(_allocationManager, _keyRegistrar) + PermissionControllerMixin(_permissionController) + Pausable(_pauserRegistry) + SemVerMixin(_version) + { + _disableInitializers(); + } + + /** + * @notice Initializes the contract with the initial paused status and owner + * @param initialOwner The initial owner of the contract + * @param initialPausedStatus The initial paused status bitmap + */ + function initialize(address initialOwner, uint256 initialPausedStatus) external initializer { + _transferOwnership(initialOwner); + _setPausedStatus(initialPausedStatus); + } + + /** + * + * EXTERNAL FUNCTIONS + * + */ + + /// @inheritdoc ICrossChainRegistry + function createGenerationReservation( + OperatorSet calldata operatorSet, + IOperatorTableCalculator operatorTableCalculator, + OperatorSetConfig calldata config, + uint256[] calldata chainIDs + ) + external + onlyWhenNotPaused(PAUSED_GENERATION_RESERVATIONS) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + // Add to active generation reservations + require(_activeGenerationReservations.add(operatorSet.key()), GenerationReservationAlreadyExists()); + emit GenerationReservationCreated(operatorSet); + + // Set the operator table calculator + _setOperatorTableCalculator(operatorSet, operatorTableCalculator, false); + // Set the operator set config + _setOperatorSetConfig(operatorSet, config, false); + // Add transport destinations + _addTransportDestinations(operatorSet, chainIDs); + } + + /// @inheritdoc ICrossChainRegistry + function removeGenerationReservation( + OperatorSet calldata operatorSet + ) + external + onlyWhenNotPaused(PAUSED_GENERATION_RESERVATIONS) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + bytes32 operatorSetKey = operatorSet.key(); + + // Remove from active generation reservations + require(_activeGenerationReservations.remove(operatorSetKey), GenerationReservationDoesNotExist()); + emit GenerationReservationRemoved(operatorSet); + + // Remove the operator table calculator + _setOperatorTableCalculator(operatorSet, IOperatorTableCalculator(address(0)), true); + // Remove the operator set config + _setOperatorSetConfig(operatorSet, OperatorSetConfig(address(0), 0), true); + // Remove all transport destinations + // TODO: This can lead to out of gas errors if there are a lot of transport destinations. + _removeTransportDestinations(operatorSet, _transportDestinations[operatorSetKey].values(), true); + } + + /// @inheritdoc ICrossChainRegistry + function setOperatorTableCalculator( + OperatorSet calldata operatorSet, + IOperatorTableCalculator operatorTableCalculator + ) + external + onlyWhenNotPaused(PAUSED_OPERATOR_TABLE_CALCULATOR) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + // Check if generation reservation exists + require(_activeGenerationReservations.contains(operatorSet.key()), GenerationReservationDoesNotExist()); + + // Set the operator table calculator + _setOperatorTableCalculator(operatorSet, operatorTableCalculator, false); + } + + /// @inheritdoc ICrossChainRegistry + function setOperatorSetConfig( + OperatorSet calldata operatorSet, + OperatorSetConfig calldata config + ) + external + onlyWhenNotPaused(PAUSED_OPERATOR_SET_CONFIG) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + // Check if generation reservation exists + require(_activeGenerationReservations.contains(operatorSet.key()), GenerationReservationDoesNotExist()); + + // Set the operator set config + _setOperatorSetConfig(operatorSet, config, false); + } + + /// @inheritdoc ICrossChainRegistry + function addTransportDestinations( + OperatorSet calldata operatorSet, + uint256[] calldata chainIDs + ) + external + onlyWhenNotPaused(PAUSED_TRANSPORT_DESTINATIONS) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + // Check if generation reservation exists + require(_activeGenerationReservations.contains(operatorSet.key()), GenerationReservationDoesNotExist()); + + _addTransportDestinations(operatorSet, chainIDs); + } + + /// @inheritdoc ICrossChainRegistry + function removeTransportDestinations( + OperatorSet calldata operatorSet, + uint256[] calldata chainIDs + ) + external + onlyWhenNotPaused(PAUSED_TRANSPORT_DESTINATIONS) + checkCanCall(operatorSet.avs) + isValidOperatorSet(operatorSet) + { + // Check if generation reservation exists + require(_activeGenerationReservations.contains(operatorSet.key()), GenerationReservationDoesNotExist()); + + _removeTransportDestinations(operatorSet, chainIDs, false); + } + + /// @inheritdoc ICrossChainRegistry + function addChainIDsToWhitelist( + uint256[] calldata chainIDs + ) external onlyOwner onlyWhenNotPaused(PAUSED_CHAIN_WHITELIST) { + for (uint256 i = 0; i < chainIDs.length; i++) { + uint256 chainID = chainIDs[i]; + + // Validate chainID + require(chainID != 0, InvalidChainId()); + + // Add to whitelist + require(_whitelistedChainIDs.add(chainID), ChainIDAlreadyWhitelisted()); + + emit ChainIDAddedToWhitelist(chainID); + } + } + + /// @inheritdoc ICrossChainRegistry + function removeChainIDsFromWhitelist( + uint256[] calldata chainIDs + ) external onlyOwner onlyWhenNotPaused(PAUSED_CHAIN_WHITELIST) { + for (uint256 i = 0; i < chainIDs.length; i++) { + uint256 chainID = chainIDs[i]; + + // Remove from whitelist + require(_whitelistedChainIDs.remove(chainID), ChainIDNotWhitelisted()); + + emit ChainIDRemovedFromWhitelist(chainID); + } + } + + /** + * + * INTERNAL FUNCTIONS + * + */ + + /** + * @dev Internal function to set the operator table calculator for an operator set + * @param operatorSet The operator set to set the calculator for + * @param operatorTableCalculator The operator table calculator contract + * @param isDelete Whether to delete the operator table calculator + */ + function _setOperatorTableCalculator( + OperatorSet memory operatorSet, + IOperatorTableCalculator operatorTableCalculator, + bool isDelete + ) internal { + if (!isDelete) { + // Validate the operator table calculator + require(address(operatorTableCalculator) != address(0), InvalidOperatorTableCalculator()); + } else { + // Need to delete the operator table calculator + require(address(operatorTableCalculator) == address(0), NeedToDelete()); + } + _operatorTableCalculators[operatorSet.key()] = operatorTableCalculator; + emit OperatorTableCalculatorSet(operatorSet, operatorTableCalculator); + } + + /** + * @dev Internal function to set the operator set config for an operator set + * @param operatorSet The operator set to set the config for + * @param config The operator set config + * @param isDelete Whether to delete the operator set config + */ + function _setOperatorSetConfig( + OperatorSet memory operatorSet, + OperatorSetConfig memory config, + bool isDelete + ) internal { + if (!isDelete) { + // Validate the operator set config + require(config.owner != address(0), InputAddressZero()); + require(config.maxStalenessPeriod != 0, StalenessPeriodZero()); + } else { + // Need to delete the operator set config + require(config.owner == address(0), NeedToDelete()); + require(config.maxStalenessPeriod == 0, NeedToDelete()); + } + _operatorSetConfigs[operatorSet.key()] = config; + emit OperatorSetConfigSet(operatorSet, config); + } + + /** + * @dev Internal function to add transport destinations for an operator set + * @param operatorSet The operator set to add destinations for + * @param chainIDs The chain IDs to add as destinations + */ + function _addTransportDestinations(OperatorSet memory operatorSet, uint256[] memory chainIDs) internal { + // Validate chainIDs array + require(chainIDs.length > 0, EmptyChainIDsArray()); + + bytes32 operatorSetKey = operatorSet.key(); + + for (uint256 i = 0; i < chainIDs.length; i++) { + uint256 chainID = chainIDs[i]; + + // Check if chainID is whitelisted + require(_whitelistedChainIDs.contains(chainID), ChainIDNotWhitelisted()); + + // Add transport destination + require(_transportDestinations[operatorSetKey].add(chainID), TransportDestinationAlreadyAdded()); + + emit TransportDestinationAdded(operatorSet, chainID); + } + } + + /** + * @dev Internal function to remove transport destinations for an operator set + * @param operatorSet The operator set to remove destinations from + * @param chainIDs The chain IDs to remove as destinations + * @param isDelete Whether to delete the transport destinations + */ + function _removeTransportDestinations( + OperatorSet memory operatorSet, + uint256[] memory chainIDs, + bool isDelete + ) internal { + // Validate chainIDs array + require(chainIDs.length > 0, EmptyChainIDsArray()); + + bytes32 operatorSetKey = operatorSet.key(); + + for (uint256 i = 0; i < chainIDs.length; i++) { + uint256 chainID = chainIDs[i]; + + // Remove transport destination + require(_transportDestinations[operatorSetKey].remove(chainID), TransportDestinationNotFound()); + + emit TransportDestinationRemoved(operatorSet, chainID); + } + + // Check final state based on isDelete flag + if (!isDelete) { + // For normal removal, at least one destination should remain + require(_transportDestinations[operatorSetKey].length() > 0, RequireAtLeastOneTransportDestination()); + } else { + // Need to delete the transport destinations + require(_transportDestinations[operatorSetKey].length() == 0, NeedToDelete()); + } + } + + /** + * + * VIEW FUNCTIONS + * + */ + + /// @inheritdoc ICrossChainRegistry + function getActiveGenerationReservations() external view returns (OperatorSet[] memory) { + uint256 length = _activeGenerationReservations.length(); + OperatorSet[] memory operatorSets = new OperatorSet[](length); + + for (uint256 i = 0; i < length; i++) { + bytes32 operatorSetKey = _activeGenerationReservations.at(i); + OperatorSet memory operatorSet = OperatorSetLib.decode(operatorSetKey); + + operatorSets[i] = operatorSet; + } + + return operatorSets; + } + + /// @inheritdoc ICrossChainRegistry + function getOperatorTableCalculator( + OperatorSet memory operatorSet + ) public view returns (IOperatorTableCalculator) { + return _operatorTableCalculators[operatorSet.key()]; + } + + /// @inheritdoc ICrossChainRegistry + function getOperatorSetConfig( + OperatorSet memory operatorSet + ) public view returns (OperatorSetConfig memory) { + return _operatorSetConfigs[operatorSet.key()]; + } + + /// @inheritdoc ICrossChainRegistry + function calculateOperatorTableBytes( + OperatorSet calldata operatorSet + ) external view returns (bytes memory) { + return abi.encode( + operatorSet, + keyRegistrar.getOperatorSetCurveType(operatorSet), + getOperatorSetConfig(operatorSet), + getOperatorTableCalculator(operatorSet).calculateOperatorTableBytes(operatorSet) + ); + } + + /// @inheritdoc ICrossChainRegistry + function getActiveTransportReservations() external view returns (OperatorSet[] memory, uint256[][] memory) { + uint256 length = _activeGenerationReservations.length(); + OperatorSet[] memory operatorSets = new OperatorSet[](length); + uint256[][] memory chainIDs = new uint256[][](length); + + for (uint256 i = 0; i < length; i++) { + bytes32 operatorSetKey = _activeGenerationReservations.at(i); + OperatorSet memory operatorSet = OperatorSetLib.decode(operatorSetKey); + + operatorSets[i] = operatorSet; + chainIDs[i] = getTransportDestinations(operatorSet); + } + + return (operatorSets, chainIDs); + } + + /// @inheritdoc ICrossChainRegistry + function getTransportDestinations( + OperatorSet memory operatorSet + ) public view returns (uint256[] memory) { + EnumerableSet.UintSet storage chainIDs = _transportDestinations[operatorSet.key()]; + uint256 length = chainIDs.length(); + + // Create result array with maximum possible size + uint256[] memory result = new uint256[](length); + uint256 count = 0; + + // Single loop to filter whitelisted chains + for (uint256 i = 0; i < length; i++) { + uint256 chainID = chainIDs.at(i); + if (_whitelistedChainIDs.contains(chainID)) { + result[count] = chainID; + count++; + } + } + + // Resize the array to the actual count using assembly + assembly { + mstore(result, count) + } + // Only return chains that are whitelisted + return result; + } + + /// @inheritdoc ICrossChainRegistry + function getSupportedChains() external view returns (uint256[] memory) { + return _whitelistedChainIDs.values(); + } +} diff --git a/src/contracts/multichain/CrossChainRegistryStorage.sol b/src/contracts/multichain/CrossChainRegistryStorage.sol new file mode 100644 index 0000000000..07fef385bf --- /dev/null +++ b/src/contracts/multichain/CrossChainRegistryStorage.sol @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.27; + +import "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import "../interfaces/ICrossChainRegistry.sol"; +import "../interfaces/IOperatorTableCalculator.sol"; +import "../interfaces/IAllocationManager.sol"; +import "../interfaces/IKeyRegistrar.sol"; +import "../libraries/OperatorSetLib.sol"; + +/** + * @title CrossChainRegistryStorage + * @author Layr Labs, Inc. + * @notice Storage contract for the CrossChainRegistry, containing all storage variables and immutables + * @dev This abstract contract is designed to be inherited by the CrossChainRegistry implementation + */ +abstract contract CrossChainRegistryStorage is ICrossChainRegistry { + using EnumerableSet for EnumerableSet.Bytes32Set; + using EnumerableSet for EnumerableSet.UintSet; + using OperatorSetLib for OperatorSet; + + // Constants + + /// @dev Index for flag that pauses generation reservations when set + uint8 internal constant PAUSED_GENERATION_RESERVATIONS = 0; + + /// @dev Index for flag that pauses operator table calculator modifications when set + uint8 internal constant PAUSED_OPERATOR_TABLE_CALCULATOR = 1; + + /// @dev Index for flag that pauses operator set config modifications when set + uint8 internal constant PAUSED_OPERATOR_SET_CONFIG = 2; + + /// @dev Index for flag that pauses transport destination modifications when set + uint8 internal constant PAUSED_TRANSPORT_DESTINATIONS = 3; + + /// @dev Index for flag that pauses chain whitelist modifications when set + uint8 internal constant PAUSED_CHAIN_WHITELIST = 4; + + // Immutables + + /// @notice The AllocationManager contract for EigenLayer + IAllocationManager public immutable allocationManager; + + /// @notice The KeyRegistrar contract for EigenLayer + IKeyRegistrar public immutable keyRegistrar; + + // Mutatables + + /// GENERATION RESERVATIONS + + /// @dev Set of operator sets with active generation reservations + EnumerableSet.Bytes32Set internal _activeGenerationReservations; + + /// @dev Mapping from operator set key to operator table calculator for active reservations + mapping(bytes32 operatorSetKey => IOperatorTableCalculator) internal _operatorTableCalculators; + + /// @dev Mapping from operator set key to operator set configuration + mapping(bytes32 operatorSetKey => OperatorSetConfig) internal _operatorSetConfigs; + + /// @dev Mapping from operator set key to set of chain IDs for transport destinations + mapping(bytes32 operatorSetKey => EnumerableSet.UintSet) internal _transportDestinations; + + /// CHAIN WHITELISTING + + /// @dev Set of whitelisted chain IDs that can be used as transport destinations + EnumerableSet.UintSet internal _whitelistedChainIDs; + + // Construction + + constructor(IAllocationManager _allocationManager, IKeyRegistrar _keyRegistrar) { + allocationManager = _allocationManager; + keyRegistrar = _keyRegistrar; + } + + /** + * @dev This empty reserved space is put in place to allow future versions to add new + * variables without shifting down storage in the inheritance chain. + * See https://docs.openzeppelin.com/contracts/4.x/upgradeable#storage_gaps + */ + uint256[43] private __gap; +} diff --git a/src/test/mocks/OperatorTableCalculatorMock.sol b/src/test/mocks/OperatorTableCalculatorMock.sol new file mode 100644 index 0000000000..a80374f32c --- /dev/null +++ b/src/test/mocks/OperatorTableCalculatorMock.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.27; + +import "src/contracts/interfaces/IOperatorTableCalculator.sol"; +import "src/contracts/libraries/OperatorSetLib.sol"; + +contract OperatorTableCalculatorMock is IOperatorTableCalculator { + using OperatorSetLib for OperatorSet; + + mapping(bytes32 => bytes) internal _operatorTableBytes; + mapping(bytes32 => address[]) internal _operators; + mapping(bytes32 => mapping(address => uint)) internal _operatorWeights; + + function calculateOperatorTableBytes(OperatorSet memory operatorSet) external view returns (bytes memory) { + return _operatorTableBytes[operatorSet.key()]; + } + + function setOperatorTableBytes(OperatorSet memory operatorSet, bytes memory operatorTableBytes) external { + _operatorTableBytes[operatorSet.key()] = operatorTableBytes; + } + + function getOperatorWeights(OperatorSet calldata operatorSet) + external + view + returns (address[] memory operators, uint[][] memory weights) + { + bytes32 key = operatorSet.key(); + operators = _operators[key]; + + weights = new uint[][](operators.length); + for (uint i = 0; i < operators.length; i++) { + weights[i] = new uint[](1); + weights[i][0] = _operatorWeights[key][operators[i]]; + } + + return (operators, weights); + } + + function getOperatorWeight(OperatorSet calldata operatorSet, address operator) external view returns (uint weight) { + return _operatorWeights[operatorSet.key()][operator]; + } + + // Helper functions for testing + function setOperators(OperatorSet memory operatorSet, address[] memory operators) external { + _operators[operatorSet.key()] = operators; + } + + function setOperatorWeight(OperatorSet memory operatorSet, address operator, uint weight) external { + _operatorWeights[operatorSet.key()][operator] = weight; + } +} diff --git a/src/test/unit/CrossChainRegistryUnit.t.sol b/src/test/unit/CrossChainRegistryUnit.t.sol new file mode 100644 index 0000000000..2b63efd37d --- /dev/null +++ b/src/test/unit/CrossChainRegistryUnit.t.sol @@ -0,0 +1,1051 @@ +// SPDX-License-Identifier: BUSL-1.1 +pragma solidity ^0.8.27; + +import "src/contracts/multichain/CrossChainRegistry.sol"; +import "src/test/utils/EigenLayerMultichainUnitTestSetup.sol"; +import "src/test/mocks/OperatorTableCalculatorMock.sol"; +import "src/contracts/interfaces/IKeyRegistrar.sol"; + +/** + * @title CrossChainRegistryUnitTests + * @notice Base contract for all CrossChainRegistry unit tests + */ +contract CrossChainRegistryUnitTests is + EigenLayerMultichainUnitTestSetup, + ICrossChainRegistryErrors, + ICrossChainRegistryTypes, + ICrossChainRegistryEvents, + IPermissionControllerErrors, + IKeyRegistrarTypes +{ + // Constants from CrossChainRegistryStorage + uint8 constant PAUSED_GENERATION_RESERVATIONS = 0; + uint8 constant PAUSED_OPERATOR_TABLE_CALCULATOR = 1; + uint8 constant PAUSED_OPERATOR_SET_CONFIG = 2; + uint8 constant PAUSED_TRANSPORT_DESTINATIONS = 3; + uint8 constant PAUSED_CHAIN_WHITELIST = 4; + + // Test state variables + address defaultAVS; + address notPermissioned = address(0xDEAD); + OperatorSet defaultOperatorSet; + OperatorTableCalculatorMock defaultCalculator; + OperatorSetConfig defaultConfig; + uint[] defaultChainIDs; + uint[] emptyChainIDs; + + function setUp() public virtual override { + EigenLayerMultichainUnitTestSetup.setUp(); + + // Set up default test values + defaultAVS = cheats.randomAddress(); + defaultOperatorSet = OperatorSet({avs: defaultAVS, id: 1}); + + defaultCalculator = new OperatorTableCalculatorMock(); + defaultConfig = OperatorSetConfig({owner: cheats.randomAddress(), maxStalenessPeriod: 1 days}); + + defaultChainIDs = new uint[](2); + defaultChainIDs[0] = 1; + defaultChainIDs[1] = 10; + + // Setup default permissions + _grantUAMRole(address(this), defaultAVS); + + // Make the operator set valid in AllocationManager + allocationManagerMock.setIsOperatorSet(defaultOperatorSet, true); + + // Whitelist chain IDs + crossChainRegistry.addChainIDsToWhitelist(defaultChainIDs); + } + + // Helper functions + function _grantUAMRole(address target, address avs) internal { + // Grant admin role first + cheats.prank(avs); + permissionController.addPendingAdmin(avs, avs); + cheats.prank(avs); + permissionController.acceptAdmin(avs); + + // Set appointee for all CrossChainRegistry functions + cheats.startPrank(avs); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.createGenerationReservation.selector); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.removeGenerationReservation.selector); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.setOperatorTableCalculator.selector); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.setOperatorSetConfig.selector); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.addTransportDestinations.selector); + permissionController.setAppointee(avs, target, address(crossChainRegistry), crossChainRegistry.removeTransportDestinations.selector); + + // Set appointee for KeyRegistrar functions + permissionController.setAppointee(avs, target, address(keyRegistrar), keyRegistrar.configureOperatorSet.selector); + cheats.stopPrank(); + } + + function _createOperatorSet(address avs, uint32 operatorSetId) internal pure returns (OperatorSet memory) { + return OperatorSet({avs: avs, id: operatorSetId}); + } + + function _createOperatorSetConfig(address owner, uint32 stalenessPeriod) internal pure returns (OperatorSetConfig memory) { + return OperatorSetConfig({owner: owner, maxStalenessPeriod: stalenessPeriod}); + } + + function _createAndWhitelistChainIDs(uint count) internal returns (uint[] memory) { + uint[] memory chainIDs = new uint[](count); + for (uint i = 0; i < count; i++) { + chainIDs[i] = 100 + i; + } + crossChainRegistry.addChainIDsToWhitelist(chainIDs); + return chainIDs; + } +} + +/** + * @title CrossChainRegistryUnitTests_initialize + * @notice Unit tests for CrossChainRegistry.initialize + */ +contract CrossChainRegistryUnitTests_initialize is CrossChainRegistryUnitTests { + function test_initialize_AlreadyInitialized() public { + cheats.expectRevert("Initializable: contract is already initialized"); + crossChainRegistry.initialize(address(this), 0); + } + + function test_initialize_CorrectOwnerAndPausedStatus() public { + // Deploy new implementation and proxy to test initialization + CrossChainRegistry freshImplementation = new CrossChainRegistry( + IAllocationManager(address(allocationManagerMock)), + IKeyRegistrar(address(keyRegistrar)), + IPermissionController(address(permissionController)), + pauserRegistry, + "1.0.0" + ); + + address newOwner = cheats.randomAddress(); + uint initialPausedStatus = (1 << PAUSED_GENERATION_RESERVATIONS) | (1 << PAUSED_TRANSPORT_DESTINATIONS); + + CrossChainRegistry freshRegistry = CrossChainRegistry( + address( + new TransparentUpgradeableProxy( + address(freshImplementation), + address(eigenLayerProxyAdmin), + abi.encodeWithSelector(CrossChainRegistry.initialize.selector, newOwner, initialPausedStatus) + ) + ) + ); + + assertEq(freshRegistry.owner(), newOwner, "Owner not set correctly"); + assertTrue(freshRegistry.paused(PAUSED_GENERATION_RESERVATIONS), "PAUSED_GENERATION_RESERVATIONS not set"); + assertTrue(freshRegistry.paused(PAUSED_TRANSPORT_DESTINATIONS), "PAUSED_TRANSPORT_DESTINATIONS not set"); + assertFalse(freshRegistry.paused(PAUSED_OPERATOR_TABLE_CALCULATOR), "PAUSED_OPERATOR_TABLE_CALCULATOR should not be set"); + } +} + +/** + * @title CrossChainRegistryUnitTests_createGenerationReservation + * @notice Unit tests for CrossChainRegistry.createGenerationReservation + */ +contract CrossChainRegistryUnitTests_createGenerationReservation is CrossChainRegistryUnitTests { + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_GENERATION_RESERVATIONS); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.createGenerationReservation(invalidOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + function test_Revert_EmptyChainIDs() public { + cheats.expectRevert(EmptyChainIDsArray.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, emptyChainIDs); + } + + function test_Revert_GenerationReservationAlreadyExists() public { + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + cheats.expectRevert(GenerationReservationAlreadyExists.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + function test_Revert_InvalidOperatorTableCalculator() public { + cheats.expectRevert(InvalidOperatorTableCalculator.selector); + crossChainRegistry.createGenerationReservation( + defaultOperatorSet, IOperatorTableCalculator(address(0)), defaultConfig, defaultChainIDs + ); + } + + function test_Revert_InvalidOperatorSetConfig_ZeroOwner() public { + OperatorSetConfig memory invalidConfig = _createOperatorSetConfig(address(0), 1 days); + + cheats.expectRevert(IPausable.InputAddressZero.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, invalidConfig, defaultChainIDs); + } + + function test_Revert_InvalidOperatorSetConfig_ZeroStalenessPeriod() public { + OperatorSetConfig memory invalidConfig = _createOperatorSetConfig(cheats.randomAddress(), 0); + + cheats.expectRevert(StalenessPeriodZero.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, invalidConfig, defaultChainIDs); + } + + function test_Revert_ChainIDNotWhitelisted() public { + uint[] memory nonWhitelistedChains = new uint[](1); + nonWhitelistedChains[0] = 999; + + cheats.expectRevert(ChainIDNotWhitelisted.selector); + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, nonWhitelistedChains); + } + + function test_createGenerationReservation_Success() public { + // Expect events + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit GenerationReservationCreated(defaultOperatorSet); + + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorTableCalculatorSet(defaultOperatorSet, defaultCalculator); + + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorSetConfigSet(defaultOperatorSet, defaultConfig); + + for (uint i = 0; i < defaultChainIDs.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit TransportDestinationAdded(defaultOperatorSet, defaultChainIDs[i]); + } + + // Make the call + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + // Verify state + OperatorSet[] memory activeReservations = crossChainRegistry.getActiveGenerationReservations(); + assertEq(activeReservations.length, 1, "Should have 1 active reservation"); + assertEq(activeReservations[0].avs, defaultOperatorSet.avs, "AVS mismatch"); + assertEq(activeReservations[0].id, defaultOperatorSet.id, "OperatorSetId mismatch"); + + assertEq( + address(crossChainRegistry.getOperatorTableCalculator(defaultOperatorSet)), address(defaultCalculator), "Calculator not set" + ); + + OperatorSetConfig memory retrievedConfig = crossChainRegistry.getOperatorSetConfig(defaultOperatorSet); + assertEq(retrievedConfig.owner, defaultConfig.owner, "Config owner mismatch"); + assertEq(retrievedConfig.maxStalenessPeriod, defaultConfig.maxStalenessPeriod, "Config staleness period mismatch"); + + uint[] memory destinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(destinations.length, defaultChainIDs.length, "Transport destinations length mismatch"); + for (uint i = 0; i < destinations.length; i++) { + assertEq(destinations[i], defaultChainIDs[i], "Transport destination mismatch"); + } + } + + function testFuzz_createGenerationReservation_MultipleChainIDs(uint8 numChainIDs) public { + numChainIDs = uint8(bound(numChainIDs, 1, 10)); + uint[] memory chainIDs = _createAndWhitelistChainIDs(numChainIDs); + + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, chainIDs); + + uint[] memory retrievedChainIDs = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(retrievedChainIDs.length, chainIDs.length, "Chain IDs length mismatch"); + for (uint i = 0; i < chainIDs.length; i++) { + assertEq(retrievedChainIDs[i], chainIDs[i], "Chain ID mismatch"); + } + } +} + +/** + * @title CrossChainRegistryUnitTests_removeGenerationReservation + * @notice Unit tests for CrossChainRegistry.removeGenerationReservation + */ +contract CrossChainRegistryUnitTests_removeGenerationReservation is CrossChainRegistryUnitTests { + function setUp() public override { + super.setUp(); + // Create a default reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_GENERATION_RESERVATIONS); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.removeGenerationReservation(defaultOperatorSet); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.removeGenerationReservation(defaultOperatorSet); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.removeGenerationReservation(invalidOperatorSet); + } + + function test_Revert_GenerationReservationDoesNotExist() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(defaultAVS, 999); + allocationManagerMock.setIsOperatorSet(nonExistentOperatorSet, true); + + cheats.expectRevert(GenerationReservationDoesNotExist.selector); + crossChainRegistry.removeGenerationReservation(nonExistentOperatorSet); + } + + function test_removeGenerationReservation_Success() public { + // Expect events + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit GenerationReservationRemoved(defaultOperatorSet); + + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorTableCalculatorSet(defaultOperatorSet, IOperatorTableCalculator(address(0))); + + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorSetConfigSet(defaultOperatorSet, OperatorSetConfig(address(0), 0)); + + for (uint i = 0; i < defaultChainIDs.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit TransportDestinationRemoved(defaultOperatorSet, defaultChainIDs[i]); + } + + // Remove the reservation + crossChainRegistry.removeGenerationReservation(defaultOperatorSet); + + // Verify state + OperatorSet[] memory activeReservations = crossChainRegistry.getActiveGenerationReservations(); + assertEq(activeReservations.length, 0, "Should have no active reservations"); + + assertEq(address(crossChainRegistry.getOperatorTableCalculator(defaultOperatorSet)), address(0), "Calculator should be removed"); + + OperatorSetConfig memory retrievedConfig = crossChainRegistry.getOperatorSetConfig(defaultOperatorSet); + assertEq(retrievedConfig.owner, address(0), "Config owner should be zero"); + assertEq(retrievedConfig.maxStalenessPeriod, 0, "Config staleness period should be zero"); + + uint[] memory destinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(destinations.length, 0, "Should have no transport destinations"); + } +} + +/** + * @title CrossChainRegistryUnitTests_setOperatorTableCalculator + * @notice Unit tests for CrossChainRegistry.setOperatorTableCalculator + */ +contract CrossChainRegistryUnitTests_setOperatorTableCalculator is CrossChainRegistryUnitTests { + OperatorTableCalculatorMock newCalculator; + + function setUp() public override { + super.setUp(); + // Create a default reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + newCalculator = new OperatorTableCalculatorMock(); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_OPERATOR_TABLE_CALCULATOR); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.setOperatorTableCalculator(defaultOperatorSet, newCalculator); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.setOperatorTableCalculator(defaultOperatorSet, newCalculator); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.setOperatorTableCalculator(invalidOperatorSet, newCalculator); + } + + function test_Revert_GenerationReservationDoesNotExist() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(defaultAVS, 999); + allocationManagerMock.setIsOperatorSet(nonExistentOperatorSet, true); + + cheats.expectRevert(GenerationReservationDoesNotExist.selector); + crossChainRegistry.setOperatorTableCalculator(nonExistentOperatorSet, newCalculator); + } + + function test_Revert_InvalidOperatorTableCalculator() public { + cheats.expectRevert(InvalidOperatorTableCalculator.selector); + crossChainRegistry.setOperatorTableCalculator(defaultOperatorSet, IOperatorTableCalculator(address(0))); + } + + function test_setOperatorTableCalculator_Success() public { + // Expect event + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorTableCalculatorSet(defaultOperatorSet, newCalculator); + + // Set new calculator + crossChainRegistry.setOperatorTableCalculator(defaultOperatorSet, newCalculator); + + // Verify state + assertEq( + address(crossChainRegistry.getOperatorTableCalculator(defaultOperatorSet)), address(newCalculator), "Calculator not updated" + ); + } + + function testFuzz_setOperatorTableCalculator_MultipleUpdates(uint8 numUpdates) public { + numUpdates = uint8(bound(numUpdates, 1, 10)); + + for (uint i = 0; i < numUpdates; i++) { + OperatorTableCalculatorMock calc = new OperatorTableCalculatorMock(); + crossChainRegistry.setOperatorTableCalculator(defaultOperatorSet, calc); + assertEq(address(crossChainRegistry.getOperatorTableCalculator(defaultOperatorSet)), address(calc), "Calculator not updated"); + } + } +} + +/** + * @title CrossChainRegistryUnitTests_setOperatorSetConfig + * @notice Unit tests for CrossChainRegistry.setOperatorSetConfig + */ +contract CrossChainRegistryUnitTests_setOperatorSetConfig is CrossChainRegistryUnitTests { + OperatorSetConfig newConfig; + + function setUp() public override { + super.setUp(); + // Create a default reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + newConfig = _createOperatorSetConfig(cheats.randomAddress(), 2 days); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_OPERATOR_SET_CONFIG); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, newConfig); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, newConfig); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.setOperatorSetConfig(invalidOperatorSet, newConfig); + } + + function test_Revert_GenerationReservationDoesNotExist() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(defaultAVS, 999); + allocationManagerMock.setIsOperatorSet(nonExistentOperatorSet, true); + + cheats.expectRevert(GenerationReservationDoesNotExist.selector); + crossChainRegistry.setOperatorSetConfig(nonExistentOperatorSet, newConfig); + } + + function test_Revert_InvalidOperatorSetConfig_ZeroOwner() public { + OperatorSetConfig memory invalidConfig = _createOperatorSetConfig(address(0), 1 days); + + cheats.expectRevert(IPausable.InputAddressZero.selector); + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, invalidConfig); + } + + function test_Revert_InvalidOperatorSetConfig_ZeroStalenessPeriod() public { + OperatorSetConfig memory invalidConfig = _createOperatorSetConfig(cheats.randomAddress(), 0); + + cheats.expectRevert(StalenessPeriodZero.selector); + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, invalidConfig); + } + + function test_setOperatorSetConfig_Success() public { + // Expect event + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit OperatorSetConfigSet(defaultOperatorSet, newConfig); + + // Set new config + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, newConfig); + + // Verify state + OperatorSetConfig memory retrievedConfig = crossChainRegistry.getOperatorSetConfig(defaultOperatorSet); + assertEq(retrievedConfig.owner, newConfig.owner, "Config owner not updated"); + assertEq(retrievedConfig.maxStalenessPeriod, newConfig.maxStalenessPeriod, "Config staleness period not updated"); + } + + function testFuzz_setOperatorSetConfig_StalenessPeriod(uint32 stalenessPeriod) public { + stalenessPeriod = uint32(bound(stalenessPeriod, 1, 365 days)); + OperatorSetConfig memory fuzzConfig = _createOperatorSetConfig(cheats.randomAddress(), stalenessPeriod); + + crossChainRegistry.setOperatorSetConfig(defaultOperatorSet, fuzzConfig); + + OperatorSetConfig memory retrievedConfig = crossChainRegistry.getOperatorSetConfig(defaultOperatorSet); + assertEq(retrievedConfig.maxStalenessPeriod, stalenessPeriod, "Staleness period not set correctly"); + } +} + +/** + * @title CrossChainRegistryUnitTests_addTransportDestinations + * @notice Unit tests for CrossChainRegistry.addTransportDestinations + */ +contract CrossChainRegistryUnitTests_addTransportDestinations is CrossChainRegistryUnitTests { + uint[] newChainIDs; + + function setUp() public override { + super.setUp(); + // Create a default reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + // Setup new chain IDs to add + newChainIDs = new uint[](2); + newChainIDs[0] = 20; + newChainIDs[1] = 30; + crossChainRegistry.addChainIDsToWhitelist(newChainIDs); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_TRANSPORT_DESTINATIONS); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.addTransportDestinations(defaultOperatorSet, newChainIDs); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.addTransportDestinations(defaultOperatorSet, newChainIDs); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.addTransportDestinations(invalidOperatorSet, newChainIDs); + } + + function test_Revert_EmptyChainIDs() public { + cheats.expectRevert(EmptyChainIDsArray.selector); + crossChainRegistry.addTransportDestinations(defaultOperatorSet, emptyChainIDs); + } + + function test_Revert_GenerationReservationDoesNotExist() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(defaultAVS, 999); + allocationManagerMock.setIsOperatorSet(nonExistentOperatorSet, true); + + cheats.expectRevert(GenerationReservationDoesNotExist.selector); + crossChainRegistry.addTransportDestinations(nonExistentOperatorSet, newChainIDs); + } + + function test_Revert_ChainIDNotWhitelisted() public { + uint[] memory nonWhitelistedChains = new uint[](1); + nonWhitelistedChains[0] = 999; + + cheats.expectRevert(ChainIDNotWhitelisted.selector); + crossChainRegistry.addTransportDestinations(defaultOperatorSet, nonWhitelistedChains); + } + + function test_Revert_TransportDestinationAlreadyAdded() public { + cheats.expectRevert(TransportDestinationAlreadyAdded.selector); + crossChainRegistry.addTransportDestinations(defaultOperatorSet, defaultChainIDs); + } + + function test_addTransportDestinations_Success() public { + // Expect events + for (uint i = 0; i < newChainIDs.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit TransportDestinationAdded(defaultOperatorSet, newChainIDs[i]); + } + + // Add new destinations + crossChainRegistry.addTransportDestinations(defaultOperatorSet, newChainIDs); + + // Verify state + uint[] memory destinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(destinations.length, defaultChainIDs.length + newChainIDs.length, "Destinations count mismatch"); + + // Check all destinations are present + bool[] memory found = new bool[](defaultChainIDs.length + newChainIDs.length); + for (uint i = 0; i < destinations.length; i++) { + for (uint j = 0; j < defaultChainIDs.length; j++) { + if (destinations[i] == defaultChainIDs[j]) found[j] = true; + } + for (uint j = 0; j < newChainIDs.length; j++) { + if (destinations[i] == newChainIDs[j]) found[defaultChainIDs.length + j] = true; + } + } + for (uint i = 0; i < found.length; i++) { + assertTrue(found[i], "Chain ID not found"); + } + } + + function testFuzz_addTransportDestinations_MultipleChainIDs(uint8 numNewChainIDs) public { + numNewChainIDs = uint8(bound(numNewChainIDs, 1, 10)); + uint[] memory fuzzChainIDs = _createAndWhitelistChainIDs(numNewChainIDs); + + crossChainRegistry.addTransportDestinations(defaultOperatorSet, fuzzChainIDs); + + uint[] memory destinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(destinations.length, defaultChainIDs.length + fuzzChainIDs.length, "Destinations count mismatch"); + } +} + +/** + * @title CrossChainRegistryUnitTests_removeTransportDestinations + * @notice Unit tests for CrossChainRegistry.removeTransportDestinations + */ +contract CrossChainRegistryUnitTests_removeTransportDestinations is CrossChainRegistryUnitTests { + uint[] chainIDsToRemove; + + function setUp() public override { + super.setUp(); + // Create a default reservation with multiple chain IDs + uint[] memory manyChainIDs = new uint[](4); + manyChainIDs[0] = 1; // Already whitelisted in base setUp + manyChainIDs[1] = 10; // Already whitelisted in base setUp + manyChainIDs[2] = 20; + manyChainIDs[3] = 30; + + // Only whitelist the new chain IDs + uint[] memory newChainIDs = new uint[](2); + newChainIDs[0] = 20; + newChainIDs[1] = 30; + crossChainRegistry.addChainIDsToWhitelist(newChainIDs); + + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, manyChainIDs); + + // Setup chain IDs to remove (subset) + chainIDsToRemove = new uint[](2); + chainIDsToRemove[0] = 10; + chainIDsToRemove[1] = 20; + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_TRANSPORT_DESTINATIONS); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, chainIDsToRemove); + } + + function test_Revert_NotPermissioned() public { + cheats.prank(notPermissioned); + cheats.expectRevert(PermissionControllerMixin.InvalidPermissions.selector); + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, chainIDsToRemove); + } + + function test_Revert_InvalidOperatorSet() public { + OperatorSet memory invalidOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Grant permission for the invalid operator set's AVS + _grantUAMRole(address(this), invalidOperatorSet.avs); + + cheats.expectRevert(InvalidOperatorSet.selector); + crossChainRegistry.removeTransportDestinations(invalidOperatorSet, chainIDsToRemove); + } + + function test_Revert_EmptyChainIDs() public { + cheats.expectRevert(EmptyChainIDsArray.selector); + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, emptyChainIDs); + } + + function test_Revert_GenerationReservationDoesNotExist() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(defaultAVS, 999); + allocationManagerMock.setIsOperatorSet(nonExistentOperatorSet, true); + + cheats.expectRevert(GenerationReservationDoesNotExist.selector); + crossChainRegistry.removeTransportDestinations(nonExistentOperatorSet, chainIDsToRemove); + } + + function test_Revert_TransportDestinationNotFound() public { + uint[] memory nonExistentChains = new uint[](1); + nonExistentChains[0] = 999; + + cheats.expectRevert(TransportDestinationNotFound.selector); + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, nonExistentChains); + } + + function test_Revert_RequireAtLeastOneTransportDestination() public { + // Get all current destinations + uint[] memory allDestinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + + // Try to remove all of them + cheats.expectRevert(RequireAtLeastOneTransportDestination.selector); + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, allDestinations); + } + + function test_removeTransportDestinations_Success() public { + // Expect events + for (uint i = 0; i < chainIDsToRemove.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit TransportDestinationRemoved(defaultOperatorSet, chainIDsToRemove[i]); + } + + // Remove destinations + crossChainRegistry.removeTransportDestinations(defaultOperatorSet, chainIDsToRemove); + + // Verify state + uint[] memory remainingDestinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(remainingDestinations.length, 2, "Should have 2 remaining destinations"); + + // Verify the correct destinations remain (1 and 30) + assertTrue( + (remainingDestinations[0] == 1 && remainingDestinations[1] == 30) + || (remainingDestinations[0] == 30 && remainingDestinations[1] == 1), + "Incorrect remaining destinations" + ); + } + + function testFuzz_removeTransportDestinations_PartialRemoval(uint8 numToRemove) public { + // Setup with many destinations + uint[] memory manyChainIDs = _createAndWhitelistChainIDs(10); + OperatorSet memory fuzzOperatorSet = _createOperatorSet(cheats.randomAddress(), 100); + allocationManagerMock.setIsOperatorSet(fuzzOperatorSet, true); + _grantUAMRole(address(this), fuzzOperatorSet.avs); + + crossChainRegistry.createGenerationReservation(fuzzOperatorSet, defaultCalculator, defaultConfig, manyChainIDs); + + // Remove some but not all + numToRemove = uint8(bound(numToRemove, 1, 9)); // Leave at least one + uint[] memory toRemove = new uint[](numToRemove); + for (uint i = 0; i < numToRemove; i++) { + toRemove[i] = manyChainIDs[i]; + } + + crossChainRegistry.removeTransportDestinations(fuzzOperatorSet, toRemove); + + uint[] memory remaining = crossChainRegistry.getTransportDestinations(fuzzOperatorSet); + assertEq(remaining.length, manyChainIDs.length - numToRemove, "Incorrect remaining count"); + } +} + +/** + * @title CrossChainRegistryUnitTests_addChainIDsToWhitelist + * @notice Unit tests for CrossChainRegistry.addChainIDsToWhitelist + */ +contract CrossChainRegistryUnitTests_addChainIDsToWhitelist is CrossChainRegistryUnitTests { + uint[] newChainIDs; + + function setUp() public override { + super.setUp(); + newChainIDs = new uint[](3); + newChainIDs[0] = 100; + newChainIDs[1] = 200; + newChainIDs[2] = 300; + } + + function test_Revert_NotOwner() public { + cheats.prank(notPermissioned); + cheats.expectRevert("Ownable: caller is not the owner"); + crossChainRegistry.addChainIDsToWhitelist(newChainIDs); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_CHAIN_WHITELIST); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.addChainIDsToWhitelist(newChainIDs); + } + + function test_Revert_InvalidChainId() public { + uint[] memory invalidChainIDs = new uint[](1); + invalidChainIDs[0] = 0; + + cheats.expectRevert(InvalidChainId.selector); + crossChainRegistry.addChainIDsToWhitelist(invalidChainIDs); + } + + function test_Revert_ChainIDAlreadyWhitelisted() public { + cheats.expectRevert(ChainIDAlreadyWhitelisted.selector); + crossChainRegistry.addChainIDsToWhitelist(defaultChainIDs); + } + + function test_addChainIDsToWhitelist_Success() public { + // Expect events + for (uint i = 0; i < newChainIDs.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit ChainIDAddedToWhitelist(newChainIDs[i]); + } + + // Add to whitelist + crossChainRegistry.addChainIDsToWhitelist(newChainIDs); + + // Verify state + uint[] memory supportedChains = crossChainRegistry.getSupportedChains(); + assertEq(supportedChains.length, defaultChainIDs.length + newChainIDs.length, "Supported chains count mismatch"); + } + + function testFuzz_addChainIDsToWhitelist_MultipleChainIDs(uint8 numChainIDs) public { + numChainIDs = uint8(bound(numChainIDs, 1, 50)); + uint[] memory fuzzChainIDs = new uint[](numChainIDs); + + for (uint i = 0; i < numChainIDs; i++) { + fuzzChainIDs[i] = 1000 + uint(i); + } + + crossChainRegistry.addChainIDsToWhitelist(fuzzChainIDs); + + uint[] memory supportedChains = crossChainRegistry.getSupportedChains(); + assertTrue(supportedChains.length >= numChainIDs, "Not all chains added"); + } +} + +/** + * @title CrossChainRegistryUnitTests_removeChainIDsFromWhitelist + * @notice Unit tests for CrossChainRegistry.removeChainIDsFromWhitelist + */ +contract CrossChainRegistryUnitTests_removeChainIDsFromWhitelist is CrossChainRegistryUnitTests { + function setUp() public override { + super.setUp(); + } + + function test_Revert_NotOwner() public { + cheats.prank(notPermissioned); + cheats.expectRevert("Ownable: caller is not the owner"); + crossChainRegistry.removeChainIDsFromWhitelist(defaultChainIDs); + } + + function test_Revert_Paused() public { + cheats.prank(pauser); + crossChainRegistry.pause(1 << PAUSED_CHAIN_WHITELIST); + + cheats.expectRevert(IPausable.CurrentlyPaused.selector); + crossChainRegistry.removeChainIDsFromWhitelist(defaultChainIDs); + } + + function test_Revert_ChainIDNotWhitelisted() public { + uint[] memory nonWhitelistedChains = new uint[](1); + nonWhitelistedChains[0] = 999; + + cheats.expectRevert(ChainIDNotWhitelisted.selector); + crossChainRegistry.removeChainIDsFromWhitelist(nonWhitelistedChains); + } + + function test_removeChainIDsFromWhitelist_Success() public { + // Expect events + for (uint i = 0; i < defaultChainIDs.length; i++) { + cheats.expectEmit(true, true, true, true, address(crossChainRegistry)); + emit ChainIDRemovedFromWhitelist(defaultChainIDs[i]); + } + + // Remove from whitelist + crossChainRegistry.removeChainIDsFromWhitelist(defaultChainIDs); + + // Verify state + uint[] memory supportedChains = crossChainRegistry.getSupportedChains(); + assertEq(supportedChains.length, 0, "Should have no supported chains"); + } + + function test_removeChainIDsFromWhitelist_AffectsTransportDestinations() public { + // Create a reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + // Remove one chain from whitelist + uint[] memory chainToRemove = new uint[](1); + chainToRemove[0] = defaultChainIDs[0]; + crossChainRegistry.removeChainIDsFromWhitelist(chainToRemove); + + // Verify transport destinations only returns whitelisted chains + uint[] memory destinations = crossChainRegistry.getTransportDestinations(defaultOperatorSet); + assertEq(destinations.length, 1, "Should only return whitelisted destination"); + assertEq(destinations[0], defaultChainIDs[1], "Wrong destination returned"); + } +} + +/** + * @title CrossChainRegistryUnitTests_getActiveGenerationReservations + * @notice Unit tests for CrossChainRegistry.getActiveGenerationReservations + */ +contract CrossChainRegistryUnitTests_getActiveGenerationReservations is CrossChainRegistryUnitTests { + function test_getActiveGenerationReservations_Empty() public { + OperatorSet[] memory reservations = crossChainRegistry.getActiveGenerationReservations(); + assertEq(reservations.length, 0, "Should have no reservations"); + } + + function test_getActiveGenerationReservations_Single() public { + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + OperatorSet[] memory reservations = crossChainRegistry.getActiveGenerationReservations(); + assertEq(reservations.length, 1, "Should have 1 reservation"); + assertEq(reservations[0].avs, defaultOperatorSet.avs, "AVS mismatch"); + assertEq(reservations[0].id, defaultOperatorSet.id, "OperatorSetId mismatch"); + } + + function testFuzz_getActiveGenerationReservations_Multiple(uint8 numReservations) public { + numReservations = uint8(bound(numReservations, 1, 10)); + + for (uint i = 0; i < numReservations; i++) { + OperatorSet memory operatorSet = _createOperatorSet(cheats.randomAddress(), uint32(i)); + allocationManagerMock.setIsOperatorSet(operatorSet, true); + _grantUAMRole(address(this), operatorSet.avs); + + crossChainRegistry.createGenerationReservation(operatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + } + + OperatorSet[] memory reservations = crossChainRegistry.getActiveGenerationReservations(); + assertEq(reservations.length, numReservations, "Reservation count mismatch"); + } +} + +/** + * @title CrossChainRegistryUnitTests_calculateOperatorTableBytes + * @notice Unit tests for CrossChainRegistry.calculateOperatorTableBytes + */ +contract CrossChainRegistryUnitTests_calculateOperatorTableBytes is CrossChainRegistryUnitTests { + bytes testOperatorTableBytes = hex"1234567890"; + + function setUp() public override { + super.setUp(); + // Create a default reservation + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + // Set up mock data + defaultCalculator.setOperatorTableBytes(defaultOperatorSet, testOperatorTableBytes); + // Configure operator set in KeyRegistrar (permissions already granted in base setUp) + keyRegistrar.configureOperatorSet(defaultOperatorSet, CurveType.BN254); + } + + function test_calculateOperatorTableBytes_Success() public { + bytes memory result = crossChainRegistry.calculateOperatorTableBytes(defaultOperatorSet); + + // Decode the result + ( + OperatorSet memory decodedOperatorSet, + CurveType curveType, + OperatorSetConfig memory decodedConfig, + bytes memory decodedOperatorTableBytes + ) = abi.decode(result, (OperatorSet, CurveType, OperatorSetConfig, bytes)); + + // Verify the decoded data + assertEq(decodedOperatorSet.avs, defaultOperatorSet.avs, "AVS mismatch"); + assertEq(decodedOperatorSet.id, defaultOperatorSet.id, "OperatorSetId mismatch"); + assertTrue(curveType == CurveType.BN254, "CurveType mismatch"); + assertEq(decodedConfig.owner, defaultConfig.owner, "Config owner mismatch"); + assertEq(decodedConfig.maxStalenessPeriod, defaultConfig.maxStalenessPeriod, "Config staleness period mismatch"); + assertEq(decodedOperatorTableBytes, testOperatorTableBytes, "OperatorTableBytes mismatch"); + } + + function test_calculateOperatorTableBytes_NonExistentOperatorSet() public { + OperatorSet memory nonExistentOperatorSet = _createOperatorSet(cheats.randomAddress(), 999); + + // Should revert when trying to call calculateOperatorTableBytes on a null calculator + cheats.expectRevert(); + crossChainRegistry.calculateOperatorTableBytes(nonExistentOperatorSet); + } +} + +/** + * @title CrossChainRegistryUnitTests_getActiveTransportReservations + * @notice Unit tests for CrossChainRegistry.getActiveTransportReservations + */ +contract CrossChainRegistryUnitTests_getActiveTransportReservations is CrossChainRegistryUnitTests { + function test_getActiveTransportReservations_Empty() public { + (OperatorSet[] memory operatorSets, uint[][] memory chainIDs) = crossChainRegistry.getActiveTransportReservations(); + assertEq(operatorSets.length, 0, "Should have no transport reservations"); + assertEq(chainIDs.length, 0, "Should have no chain IDs"); + } + + function test_getActiveTransportReservations_Single() public { + crossChainRegistry.createGenerationReservation(defaultOperatorSet, defaultCalculator, defaultConfig, defaultChainIDs); + + (OperatorSet[] memory operatorSets, uint[][] memory chainIDs) = crossChainRegistry.getActiveTransportReservations(); + assertEq(operatorSets.length, 1, "Should have 1 transport reservation"); + assertEq(operatorSets[0].avs, defaultOperatorSet.avs, "AVS mismatch"); + assertEq(operatorSets[0].id, defaultOperatorSet.id, "OperatorSetId mismatch"); + assertEq(chainIDs[0].length, defaultChainIDs.length, "Chain IDs length mismatch"); + for (uint i = 0; i < chainIDs[0].length; i++) { + assertEq(chainIDs[0][i], defaultChainIDs[i], "Chain ID mismatch"); + } + } + + function test_getActiveTransportReservations_Multiple() public { + uint numReservations = 3; + + for (uint i = 0; i < numReservations; i++) { + OperatorSet memory operatorSet = _createOperatorSet(cheats.randomAddress(), uint32(i)); + allocationManagerMock.setIsOperatorSet(operatorSet, true); + _grantUAMRole(address(this), operatorSet.avs); + + // Create unique chain IDs for each iteration + uint[] memory chainIDs = new uint[](i + 1); + for (uint j = 0; j <= i; j++) { + // Use a formula that ensures unique chainIDs across iterations + chainIDs[j] = 100 + uint(i * 10 + j); + } + crossChainRegistry.addChainIDsToWhitelist(chainIDs); + + crossChainRegistry.createGenerationReservation(operatorSet, defaultCalculator, defaultConfig, chainIDs); + } + + (OperatorSet[] memory operatorSets, uint[][] memory chainIDs) = crossChainRegistry.getActiveTransportReservations(); + assertEq(operatorSets.length, numReservations, "Transport reservation count mismatch"); + + for (uint i = 0; i < numReservations; i++) { + assertEq(chainIDs[i].length, i + 1, "Chain IDs length mismatch for reservation"); + } + } +} + +/** + * @title CrossChainRegistryUnitTests_getSupportedChains + * @notice Unit tests for CrossChainRegistry.getSupportedChains + */ +contract CrossChainRegistryUnitTests_getSupportedChains is CrossChainRegistryUnitTests { + function test_getSupportedChains_Initial() public { + uint[] memory supportedChains = crossChainRegistry.getSupportedChains(); + assertEq(supportedChains.length, defaultChainIDs.length, "Should have default chains"); + for (uint i = 0; i < supportedChains.length; i++) { + bool found = false; + for (uint j = 0; j < defaultChainIDs.length; j++) { + if (supportedChains[i] == defaultChainIDs[j]) { + found = true; + break; + } + } + assertTrue(found, "Chain ID not found in supported chains"); + } + } + + function testFuzz_getSupportedChains_AddAndRemove(uint8 numToAdd, uint8 numToRemove) public { + numToAdd = uint8(bound(numToAdd, 1, 20)); + numToRemove = uint8(bound(numToRemove, 0, defaultChainIDs.length)); + + // Add chains + uint[] memory newChains = _createAndWhitelistChainIDs(numToAdd); + + uint[] memory supportedChains = crossChainRegistry.getSupportedChains(); + assertEq(supportedChains.length, defaultChainIDs.length + numToAdd, "Chain count after add mismatch"); + + // Remove some default chains + if (numToRemove > 0) { + uint[] memory chainsToRemove = new uint[](numToRemove); + for (uint i = 0; i < numToRemove; i++) { + chainsToRemove[i] = defaultChainIDs[i]; + } + crossChainRegistry.removeChainIDsFromWhitelist(chainsToRemove); + + supportedChains = crossChainRegistry.getSupportedChains(); + assertEq(supportedChains.length, defaultChainIDs.length + numToAdd - numToRemove, "Chain count after remove mismatch"); + } + } +} diff --git a/src/test/utils/EigenLayerMultichainUnitTestSetup.sol b/src/test/utils/EigenLayerMultichainUnitTestSetup.sol index c9bccdb1bb..c10ee0a724 100644 --- a/src/test/utils/EigenLayerMultichainUnitTestSetup.sol +++ b/src/test/utils/EigenLayerMultichainUnitTestSetup.sol @@ -5,6 +5,7 @@ import "src/test/utils/EigenLayerUnitTestSetup.sol"; import "src/contracts/permissions/KeyRegistrar.sol"; import "src/test/mocks/BN254CertificateVerifierMock.sol"; import "src/test/mocks/ECDSACertificateVerifierMock.sol"; +import "src/contracts/multichain/CrossChainRegistry.sol"; abstract contract EigenLayerMultichainUnitTestSetup is EigenLayerUnitTestSetup { using StdStyle for *; @@ -17,6 +18,8 @@ abstract contract EigenLayerMultichainUnitTestSetup is EigenLayerUnitTestSetup { /// @dev Mocks BN254CertificateVerifierMock bn254CertificateVerifierMock; ECDSACertificateVerifierMock ecdsaCertificateVerifierMock; + CrossChainRegistry crossChainRegistry; + CrossChainRegistry crossChainRegistryImplementation; function setUp() public virtual override { // Setup Core Mocks @@ -31,8 +34,33 @@ abstract contract EigenLayerMultichainUnitTestSetup is EigenLayerUnitTestSetup { bn254CertificateVerifierMock = new BN254CertificateVerifierMock(); ecdsaCertificateVerifierMock = new ECDSACertificateVerifierMock(); + // Deploy CrossChainRegistry implementation + crossChainRegistryImplementation = new CrossChainRegistry( + IAllocationManager(address(allocationManagerMock)), + IKeyRegistrar(address(keyRegistrar)), + IPermissionController(address(permissionController)), + pauserRegistry, + "1.0.0" + ); + + // Deploy CrossChainRegistry proxy + crossChainRegistry = CrossChainRegistry( + address( + new TransparentUpgradeableProxy( + address(crossChainRegistryImplementation), + address(eigenLayerProxyAdmin), + abi.encodeWithSelector( + CrossChainRegistry.initialize.selector, + address(this), // initial owner + 0 // initial paused status + ) + ) + ) + ); + // Filter out mocks isExcludedFuzzAddress[address(bn254CertificateVerifierMock)] = true; isExcludedFuzzAddress[address(ecdsaCertificateVerifierMock)] = true; + isExcludedFuzzAddress[address(crossChainRegistry)] = true; } }