From 4ab8f9abae955018d37fe0beb60a4dd1d300048d Mon Sep 17 00:00:00 2001
From: danoctavian <danoctavian91@gmail.com>
Date: Sun, 1 Sep 2024 21:32:27 +0300
Subject: [PATCH] add invariant measurement

---
 test/integration/M3/Base.t.sol                | 26 ++++++++++++++
 .../M3/WithdrawalsWithRewards-Scenario.t.sol  | 36 +++++++++++++++++--
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/test/integration/M3/Base.t.sol b/test/integration/M3/Base.t.sol
index 52d214806..4460377fe 100644
--- a/test/integration/M3/Base.t.sol
+++ b/test/integration/M3/Base.t.sol
@@ -27,6 +27,8 @@ import {RewardsDistributor} from "../../../src/RewardsDistributor.sol";
 import {StakingNode} from "../../../src/StakingNode.sol";
 import {WithdrawalQueueManager} from "../../../src/WithdrawalQueueManager.sol";
 import {ynETHRedemptionAssetsVault} from "../../../src/ynETHRedemptionAssetsVault.sol";
+import {IStakingNode} from "../../../src/interfaces/IStakingNodesManager.sol";
+
 
 import "forge-std/console.sol";
 import "forge-std/Test.sol";
@@ -257,4 +259,28 @@ contract Base is Test, Utils {
         vm.prank(actors.ops.VALIDATOR_MANAGER);
         stakingNodesManager.registerValidators(validatorData);
     }
+
+    function runSystemStateInvariants(
+        uint256 previousTotalAssets,
+        uint256 previousTotalSupply,
+        uint256[] memory previousStakingNodeBalances
+    ) public {  
+        assertEq(yneth.totalAssets(), previousTotalAssets, "Total assets integrity check failed");
+        assertEq(yneth.totalSupply(), previousTotalSupply, "Share mint integrity check failed");
+        for (uint i = 0; i < previousStakingNodeBalances.length; i++) {
+            IStakingNode stakingNodeInstance = stakingNodesManager.nodes(i);
+            uint256 currentStakingNodeBalance = stakingNodeInstance.getETHBalance();
+            assertEq(currentStakingNodeBalance, previousStakingNodeBalances[i], "Staking node balance integrity check failed for node ID: ");
+        }
+	}
+
+    function getAllStakingNodeBalances() public view returns (uint256[] memory) {
+        uint256[] memory balances = new uint256[](stakingNodesManager.nodesLength());
+        for (uint256 i = 0; i < stakingNodesManager.nodesLength(); i++) {
+            IStakingNode stakingNode = stakingNodesManager.nodes(i);
+            balances[i] = stakingNode.getETHBalance();
+        }
+        return balances;
+    }
+
 }
diff --git a/test/integration/M3/WithdrawalsWithRewards-Scenario.t.sol b/test/integration/M3/WithdrawalsWithRewards-Scenario.t.sol
index c29c4437d..bc8cf66c6 100644
--- a/test/integration/M3/WithdrawalsWithRewards-Scenario.t.sol
+++ b/test/integration/M3/WithdrawalsWithRewards-Scenario.t.sol
@@ -29,6 +29,15 @@ contract M3WithdrawalsTest is Base {
 
     uint256 public amount;
 
+    struct TestState {
+        uint256 totalAssetsBefore;
+        uint256 totalSupplyBefore;
+        uint256[] stakingNodeBalancesBefore;
+        uint256 previousYnETHRedemptionAssetsVaultBalance;
+        uint256 previousYnETHBalance;
+    }
+
+
     function setUp() public override {
         super.setUp();
     }
@@ -47,6 +56,7 @@ contract M3WithdrawalsTest is Base {
             uint256 userYnETHBalance = yneth.balanceOf(user);
             console.log("User ynETH balance after deposit:", userYnETHBalance);
         }
+
         // create staking node
         {
             vm.prank(actors.ops.STAKING_NODE_CREATOR);
@@ -54,6 +64,15 @@ contract M3WithdrawalsTest is Base {
             nodeId = stakingNodesManager.nodesLength() - 1;
         }
 
+
+        TestState memory state = TestState({
+            totalAssetsBefore: yneth.totalAssets(),
+            totalSupplyBefore: yneth.totalSupply(),
+            stakingNodeBalancesBefore: getAllStakingNodeBalances(),
+            previousYnETHRedemptionAssetsVaultBalance: ynETHRedemptionAssetsVaultInstance.availableRedemptionAssets(),
+            previousYnETHBalance: address(yneth).balance
+        });
+
         // Calculate validator count based on amount
         uint256 validatorCount = amount / 32 ether;
 
@@ -73,6 +92,9 @@ contract M3WithdrawalsTest is Base {
             registerValidators(nodeIds);
         }
 
+        state.stakingNodeBalancesBefore[nodeId] += validatorCount * 32 ether;
+        runSystemStateInvariants(state.totalAssetsBefore, state.totalSupplyBefore, state.stakingNodeBalancesBefore);
+
         // verify withdrawal credentials
         {
 
@@ -91,6 +113,8 @@ contract M3WithdrawalsTest is Base {
             // _testVerifyWithdrawalCredentials();
         }
 
+        runSystemStateInvariants(state.totalAssetsBefore, state.totalSupplyBefore, state.stakingNodeBalancesBefore);
+
         uint256 accumulatedRewards;
         {
             uint256 epochCount = 30;
@@ -102,9 +126,6 @@ contract M3WithdrawalsTest is Base {
             accumulatedRewards += validatorCount * epochCount * 1e9; // 1 GWEI per Epoch per Validator
         }
 
-        // Log accumulated rewards
-        console.log("Accumulated rewards:", accumulatedRewards);
-
 
         // exit validators
         {
@@ -114,6 +135,8 @@ contract M3WithdrawalsTest is Base {
             beaconChain.advanceEpoch();
         }
 
+        runSystemStateInvariants(state.totalAssetsBefore, state.totalSupplyBefore, state.stakingNodeBalancesBefore);
+
 
         // start checkpoint
         {
@@ -122,6 +145,8 @@ contract M3WithdrawalsTest is Base {
             vm.stopPrank();
         }
 
+        runSystemStateInvariants(state.totalAssetsBefore, state.totalSupplyBefore, state.stakingNodeBalancesBefore);
+
         // verify checkpoints
         {
             IStakingNode _node = stakingNodesManager.nodes(nodeId);
@@ -132,6 +157,11 @@ contract M3WithdrawalsTest is Base {
             });
         }
 
+        // Rewards accumulated after verifying the checkpoint
+        state.totalAssetsBefore += accumulatedRewards;
+        state.stakingNodeBalancesBefore[nodeId] += accumulatedRewards;
+        runSystemStateInvariants(state.totalAssetsBefore, state.totalSupplyBefore, state.stakingNodeBalancesBefore);
+
         uint256 withdrawnAmount = 32 ether * validatorIndices.length + accumulatedRewards;
 
         // queue withdrawals