Skip to content

Commit

Permalink
feature: add the onchain middleware when createSelectFork (#448)
Browse files Browse the repository at this point in the history
* refactor: split cheatcode into submodules

* feature: add the onchain middleware when createSelectFork
  • Loading branch information
jacob-chia authored Apr 1, 2024
1 parent 27ddefd commit 8c88b03
Show file tree
Hide file tree
Showing 7 changed files with 886 additions and 686 deletions.
35 changes: 19 additions & 16 deletions src/evm/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
ops::Deref,
rc::Rc,
str::FromStr,
sync::Arc,
sync::{Arc, RwLock},
time::{SystemTime, UNIX_EPOCH},
};

Expand Down Expand Up @@ -174,7 +174,9 @@ where
pub pc_to_create: HashMap<(EVMAddress, usize), usize>,
pub pc_to_call_hash: HashMap<(EVMAddress, usize, usize), HashSet<Vec<u8>>>,
pub middlewares_enabled: bool,
pub middlewares: Rc<RefCell<Vec<Rc<RefCell<dyn Middleware<SC>>>>>>,
// If you use RefCell, modifying middlewares during execution will cause a panic
// because the executor borrows middlewares over its entire lifetime.
pub middlewares: RwLock<Vec<Rc<RefCell<dyn Middleware<SC>>>>>,

pub coverage_changed: bool,

Expand Down Expand Up @@ -277,7 +279,7 @@ where
pc_to_create: self.pc_to_create.clone(),
pc_to_call_hash: self.pc_to_call_hash.clone(),
middlewares_enabled: false,
middlewares: Rc::new(RefCell::new(Default::default())),
middlewares: RwLock::new(Default::default()),
coverage_changed: false,
flashloan_middleware: self.flashloan_middleware.clone(),
middlewares_latent_call_actions: vec![],
Expand Down Expand Up @@ -338,7 +340,7 @@ where
pc_to_create: HashMap::new(),
pc_to_call_hash: HashMap::new(),
middlewares_enabled: false,
middlewares: Rc::new(RefCell::new(Default::default())),
middlewares: RwLock::new(Default::default()),
coverage_changed: false,
flashloan_middleware: None,
middlewares_latent_call_actions: vec![],
Expand Down Expand Up @@ -404,27 +406,27 @@ where

pub fn remove_all_middlewares(&mut self) {
self.middlewares_enabled = false;
self.middlewares.deref().borrow_mut().clear();
self.middlewares = RwLock::new(Default::default());
}

pub fn add_middlewares(&mut self, middlewares: Rc<RefCell<dyn Middleware<SC>>>) {
pub fn add_middlewares(&mut self, middleware: Rc<RefCell<dyn Middleware<SC>>>) {
self.middlewares_enabled = true;
// let ty = middlewares.deref().borrow().get_type();
self.middlewares.deref().borrow_mut().push(middlewares);
self.middlewares.write().unwrap().push(middleware);
}

pub fn remove_middlewares(&mut self, middlewares: Rc<RefCell<dyn Middleware<SC>>>) {
let ty = middlewares.deref().borrow().get_type();

self.middlewares
.deref()
.borrow_mut()
.write()
.unwrap()
.retain(|x| x.deref().borrow().get_type() != ty);
}

pub fn remove_middlewares_by_ty(&mut self, ty: &MiddlewareType) {
self.middlewares
.deref()
.borrow_mut()
.write()
.unwrap()
.retain(|x| x.deref().borrow().get_type() != *ty);
}

Expand Down Expand Up @@ -986,8 +988,9 @@ macro_rules! invoke_middlewares {
$host.clear_codedata();
}

for middleware in &mut $host.middlewares.clone().deref().borrow_mut().iter_mut() {
middleware.deref().deref().borrow_mut().$invoke($interp, $host, $state $(, $arg)*);
let mut middlewares = $host.middlewares.read().unwrap().clone();
for middleware in middlewares.iter_mut() {
middleware.deref().borrow_mut().$invoke($interp, $host, $state $(, $arg)*);
}

if !$host.setcode_data.is_empty() {
Expand Down Expand Up @@ -1464,9 +1467,9 @@ where

unsafe {
if self.middlewares_enabled {
for middleware in &mut self.middlewares.clone().deref().borrow_mut().iter_mut() {
let mut middlewares = self.middlewares.read().unwrap().clone();
for middleware in middlewares.iter_mut() {
middleware
.deref()
.deref()
.borrow_mut()
.on_return(interp, self, state, &ret_buffer);
Expand Down
Loading

0 comments on commit 8c88b03

Please sign in to comment.