diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index 574d310ca19fc..4517f075e7fb0 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -2971,6 +2971,26 @@ "status": "stable", "safety": "safe" }, + { + "func": { + "id": "assumeNoRevert", + "description": "Discard this run's fuzz inputs and generate new ones if next call reverted.", + "declaration": "function assumeNoRevert() external pure;", + "visibility": "external", + "mutability": "pure", + "signature": "assumeNoRevert()", + "selector": "0x285b366a", + "selectorBytes": [ + 40, + 91, + 54, + 106 + ] + }, + "group": "testing", + "status": "stable", + "safety": "safe" + }, { "func": { "id": "blobBaseFee", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index 575ccfa84a5a7..980bab066a3a0 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -678,6 +678,10 @@ interface Vm { #[cheatcode(group = Testing, safety = Safe)] function assume(bool condition) external pure; + /// Discard this run's fuzz inputs and generate new ones if next call reverted. + #[cheatcode(group = Testing, safety = Safe)] + function assumeNoRevert() external pure; + /// Writes a breakpoint to jump to in the debugger. #[cheatcode(group = Testing, safety = Safe)] function breakpoint(string calldata char) external; diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 1ff2a6e999dcc..f5238d810c8bd 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -9,9 +9,12 @@ use crate::{ }, inspector::utils::CommonCreateInput, script::{Broadcast, ScriptWallets}, - test::expect::{ - self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedEmit, - ExpectedRevert, ExpectedRevertKind, + test::{ + assume::AssumeNoRevert, + expect::{ + self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedEmit, + ExpectedRevert, ExpectedRevertKind, + }, }, utils::IgnoredTraces, CheatsConfig, CheatsCtxt, DynCheatcode, Error, Result, Vm, @@ -25,7 +28,7 @@ use foundry_config::Config; use foundry_evm_core::{ abi::Vm::stopExpectSafeMemoryCall, backend::{DatabaseExt, RevertDiagnostic}, - constants::{CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS}, + constants::{CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS, MAGIC_ASSUME}, utils::new_evm_with_existing_context, InspectorExt, }; @@ -294,6 +297,9 @@ pub struct Cheatcodes { /// Expected revert information pub expected_revert: Option, + /// Assume next call can revert and discard fuzz run if it does. + pub assume_no_revert: Option, + /// Additional diagnostic for reverts pub fork_revert_diagnostic: Option, @@ -384,6 +390,7 @@ impl Cheatcodes { gas_price: Default::default(), prank: Default::default(), expected_revert: Default::default(), + assume_no_revert: Default::default(), fork_revert_diagnostic: Default::default(), accesses: Default::default(), recorded_account_diffs_stack: Default::default(), @@ -1106,6 +1113,19 @@ impl Inspector for Cheatcodes { } } + // Handle assume not revert cheatcode. + if let Some(assume_no_revert) = &self.assume_no_revert { + if ecx.journaled_state.depth() == assume_no_revert.depth && !cheatcode_call { + // Discard run if we're at the same depth as cheatcode and call reverted. + if outcome.result.is_revert() { + outcome.result.output = Error::from(MAGIC_ASSUME).abi_encode().into(); + return outcome; + } + // Call didn't revert, reset `assume_no_revert` state. + self.assume_no_revert = None; + } + } + // Handle expected reverts if let Some(expected_revert) = &self.expected_revert { if ecx.journaled_state.depth() <= expected_revert.depth { diff --git a/crates/cheatcodes/src/test.rs b/crates/cheatcodes/src/test.rs index 279150dff89f4..0945f37a89c2a 100644 --- a/crates/cheatcodes/src/test.rs +++ b/crates/cheatcodes/src/test.rs @@ -3,25 +3,15 @@ use chrono::DateTime; use std::env; -use crate::{Cheatcode, Cheatcodes, CheatsCtxt, DatabaseExt, Error, Result, Vm::*}; +use crate::{Cheatcode, Cheatcodes, CheatsCtxt, DatabaseExt, Result, Vm::*}; use alloy_primitives::Address; use alloy_sol_types::SolValue; -use foundry_evm_core::constants::{MAGIC_ASSUME, MAGIC_SKIP}; +use foundry_evm_core::constants::MAGIC_SKIP; pub(crate) mod assert; +pub(crate) mod assume; pub(crate) mod expect; -impl Cheatcode for assumeCall { - fn apply(&self, _state: &mut Cheatcodes) -> Result { - let Self { condition } = self; - if *condition { - Ok(Default::default()) - } else { - Err(Error::from(MAGIC_ASSUME)) - } - } -} - impl Cheatcode for breakpoint_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self { char } = self; diff --git a/crates/cheatcodes/src/test/assume.rs b/crates/cheatcodes/src/test/assume.rs new file mode 100644 index 0000000000000..e100eeb9d1c43 --- /dev/null +++ b/crates/cheatcodes/src/test/assume.rs @@ -0,0 +1,29 @@ +use crate::{Cheatcode, Cheatcodes, CheatsCtxt, Error, Result}; +use foundry_evm_core::{backend::DatabaseExt, constants::MAGIC_ASSUME}; +use spec::Vm::{assumeCall, assumeNoRevertCall}; +use std::fmt::Debug; + +#[derive(Clone, Debug)] +pub struct AssumeNoRevert { + /// The call depth at which the cheatcode was added. + pub depth: u64, +} + +impl Cheatcode for assumeCall { + fn apply(&self, _state: &mut Cheatcodes) -> Result { + let Self { condition } = self; + if *condition { + Ok(Default::default()) + } else { + Err(Error::from(MAGIC_ASSUME)) + } + } +} + +impl Cheatcode for assumeNoRevertCall { + fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { + ccx.state.assume_no_revert = + Some(AssumeNoRevert { depth: ccx.ecx.journaled_state.depth() }); + Ok(Default::default()) + } +} diff --git a/crates/forge/tests/cli/test_cmd.rs b/crates/forge/tests/cli/test_cmd.rs index 23ffa890d6030..19e5c4853d3ba 100644 --- a/crates/forge/tests/cli/test_cmd.rs +++ b/crates/forge/tests/cli/test_cmd.rs @@ -1834,3 +1834,82 @@ contract CounterTest is DSTest { ... "#]]); }); + +forgetest_init!(test_assume_no_revert, |prj, cmd| { + prj.wipe_contracts(); + prj.insert_ds_test(); + prj.insert_vm(); + prj.clear(); + + prj.add_source( + "Counter.t.sol", + r#"pragma solidity 0.8.24; +import {Vm} from "./Vm.sol"; +import {DSTest} from "./test.sol"; +contract CounterWithRevert { + error CountError(); + error CheckError(); + + function count(uint256 a) public pure returns (uint256) { + if (a > 1000 || a < 10) { + revert CountError(); + } + return 99999999; + } + function check(uint256 a) public pure { + if (a == 99999999) { + revert CheckError(); + } + } + function dummy() public pure {} +} + +contract CounterRevertTest is DSTest { + Vm vm = Vm(HEVM_ADDRESS); + + function test_assume_no_revert_pass(uint256 a) public { + CounterWithRevert counter = new CounterWithRevert(); + vm.assumeNoRevert(); + a = counter.count(a); + assertEq(a, 99999999); + } + function test_assume_no_revert_fail_assert(uint256 a) public { + CounterWithRevert counter = new CounterWithRevert(); + vm.assumeNoRevert(); + a = counter.count(a); + // Test should fail on next assertion. + assertEq(a, 1); + } + function test_assume_no_revert_fail_in_2nd_call(uint256 a) public { + CounterWithRevert counter = new CounterWithRevert(); + vm.assumeNoRevert(); + a = counter.count(a); + // Test should revert here (not in scope of `assumeNoRevert` cheatcode). + counter.check(a); + assertEq(a, 99999999); + } + function test_assume_no_revert_fail_in_3rd_call(uint256 a) public { + CounterWithRevert counter = new CounterWithRevert(); + vm.assumeNoRevert(); + a = counter.count(a); + // Test `assumeNoRevert` applied to non reverting call should not be available for next reverting call. + vm.assumeNoRevert(); + counter.dummy(); + // Test will revert here (not in scope of `assumeNoRevert` cheatcode). + counter.check(a); + assertEq(a, 99999999); + } +} + "#, + ) + .unwrap(); + + cmd.args(["test"]).with_no_redact().assert_failure().stdout_eq(str![[r#" +... +[FAIL. Reason: assertion failed; counterexample: [..]] test_assume_no_revert_fail_assert(uint256) [..] +[FAIL. Reason: CheckError(); counterexample: [..]] test_assume_no_revert_fail_in_2nd_call(uint256) [..] +[FAIL. Reason: CheckError(); counterexample: [..]] test_assume_no_revert_fail_in_3rd_call(uint256) [..] +[PASS] test_assume_no_revert_pass(uint256) (runs: 256, [..]) +... +"#]]); +}); diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index b929053dac012..5b6750237addb 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -144,6 +144,7 @@ interface Vm { function assertTrue(bool condition) external pure; function assertTrue(bool condition, string calldata error) external pure; function assume(bool condition) external pure; + function assumeNoRevert() external pure; function blobBaseFee(uint256 newBlobBaseFee) external; function blobhashes(bytes32[] calldata hashes) external; function breakpoint(string calldata char) external;