diff --git a/faux_macros/src/methods/morphed.rs b/faux_macros/src/methods/morphed.rs
index bc68a71..fd0b93c 100644
--- a/faux_macros/src/methods/morphed.rs
+++ b/faux_macros/src/methods/morphed.rs
@@ -171,7 +171,6 @@ impl<'a> Signature<'a> {
         let struct_and_method_name = format!("{}::{}", morphed_ty.to_token_stream(), name);
 
         quote! {
-            let mut q = q.try_lock().unwrap();
             unsafe {
                 match q.call_mock(<Self>::#faux_ident, #args) {
                     std::result::Result::Ok(o) => o,
@@ -276,7 +275,7 @@ impl<'a> MethodData<'a> {
             match &mut self.0 {
                 faux::MaybeFaux::Faux(faux) => faux::When::new(
                     <Self>::#faux_ident,
-                    faux.get_mut().unwrap()
+                    faux
                 ),
                 faux::MaybeFaux::Real(_) => panic!("not allowed to mock a real instance!"),
             }
diff --git a/src/mock_store.rs b/src/mock_store.rs
index 7b57ec9..1928369 100644
--- a/src/mock_store.rs
+++ b/src/mock_store.rs
@@ -18,7 +18,7 @@ use std::collections::HashMap;
 #[derive(Debug)]
 pub enum MaybeFaux<T> {
     Real(T),
-    Faux(std::sync::Mutex<MockStore>),
+    Faux(MockStore),
 }
 
 impl<T: Clone> Clone for MaybeFaux<T> {
@@ -32,20 +32,20 @@ impl<T> MaybeFaux<T> {
     pub fn faux() -> Self {
-        MaybeFaux::Faux(std::sync::Mutex::new(MockStore::new()))
+        MaybeFaux::Faux(MockStore::new())
     }
 }
 
 #[derive(Debug, Default)]
 #[doc(hidden)]
 pub struct MockStore {
-    mocks: HashMap<usize, Mock<'static, (), ()>>,
+    mocks: std::sync::Mutex<HashMap<usize, std::sync::Arc<std::sync::Mutex<Mock<'static, (), ()>>>>>,
 }
 
 impl MockStore {
     fn new() -> Self {
         MockStore {
-            mocks: HashMap::new(),
+            mocks: std::sync::Mutex::new(HashMap::new()),
         }
     }
@@ -63,7 +63,10 @@ impl MockStore {
     }
 
     fn store_mock<R, I, O>(&mut self, id: fn(R, I) -> O, mock: Mock<'static, I, O>) {
-        self.mocks.insert(id as usize, unsafe { mock.unchecked() });
+        self.mocks.lock().unwrap().insert(
+            id as usize,
+            std::sync::Arc::new(std::sync::Mutex::new(unsafe { mock.unchecked() })),
+        );
     }
 
     #[doc(hidden)]
@@ -71,12 +74,15 @@ impl MockStore {
     ///
    /// Do *NOT* call this function directly.
    /// This should only be called by the generated code from #[faux::methods]
-    pub unsafe fn call_mock<R, I, O>(&mut self, id: fn(R, I) -> O, input: I) -> Result<O, String> {
+    pub unsafe fn call_mock<R, I, O>(&self, id: fn(R, I) -> O, input: I) -> Result<O, String> {
         let stub = self
             .mocks
-            .get_mut(&(id as usize))
+            .lock()
+            .unwrap()
+            .get(&(id as usize))
+            .map(|v| v.clone())
             .ok_or_else(|| "method was never mocked".to_string())?;
-
-        stub.call(input)
+        let mut locked = stub.lock().unwrap();
+        locked.call(input)
     }
 }
diff --git a/tests/threads.rs b/tests/threads.rs
new file mode 100644
index 0000000..9f11c02
--- /dev/null
+++ b/tests/threads.rs
@@ -0,0 +1,105 @@
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::time::Duration;
+
+#[faux::create]
+pub struct Foo {}
+
+#[faux::methods]
+impl Foo {
+    pub fn foo(&self) {
+        todo!()
+    }
+    pub fn bar(&self) {
+        todo!()
+    }
+}
+
+#[test]
+fn mock_multi_threaded_access() {
+    let mut fake = Foo::faux();
+    let done_count = Arc::new(AtomicUsize::new(0));
+
+    faux::when!(fake.bar).then(move |()| {});
+
+    let shared_fake1 = Arc::new(fake);
+    let shared_fake2 = shared_fake1.clone();
+
+    let dc1 = done_count.clone();
+    let _t1 = std::thread::spawn(move || {
+        for _ in 0..10000 {
+            shared_fake1.bar();
+        }
+        dc1.fetch_add(1, Ordering::Relaxed);
+    });
+
+    let dc2 = done_count.clone();
+    let _t2 = std::thread::spawn(move || {
+        for _ in 0..10000 {
+            shared_fake2.bar();
+        }
+        dc2.fetch_add(1, Ordering::Relaxed);
+    });
+
+    std::thread::sleep(Duration::from_millis(100)); // FIXME maybe we can do better?
+    assert_eq!(done_count.load(Ordering::Relaxed), 2);
+}
+
+fn spin_until(a: &Arc<AtomicUsize>, val: usize) {
+    loop {
+        if a.load(Ordering::SeqCst) == val {
+            break;
+        }
+    }
+}
+
+#[test]
+fn mutex_does_not_lock_entire_mock() {
+    // Assume that calling a mocked method locks the entire mock. Then the following scenario can happen:
+    // * Thread 1 takes a lock L.
+    // * Thread 2 calls mocked foo(), which tries to take L and blocks on it.
+    // * While holding L, thread 1 calls mocked bar(), blocking on the mock.
+    // * We get a deadlock, even though bar() is seemingly unrelated to lock-taking foo().
+
+    let mut fake = Foo::faux();
+    let l = Arc::new(Mutex::new(10));
+    let l_foo = l.clone();
+
+    let done_count = Arc::new(AtomicUsize::new(0));
+    let call_order = Arc::new(AtomicUsize::new(0));
+
+    let co_foo = call_order.clone();
+    faux::when!(fake.foo).then(move |()| {
+        co_foo.swap(2, Ordering::SeqCst); // Let thread 1 call bar()
+        let _guard = l_foo.lock(); // Bind the guard so L is actually held, not dropped immediately
+        spin_until(&co_foo, 3); // Hold the lock until thread 1 returns from bar()
+    });
+    faux::when!(fake.bar).then(move |()| {});
+
+    let shared_fake1 = Arc::new(fake);
+    let shared_fake2 = shared_fake1.clone();
+
+    let dc1 = done_count.clone();
+    let co1 = call_order.clone();
+    let _t1 = std::thread::spawn(move || {
+        let _guard = l.lock(); // Hold L while calling bar()
+        co1.swap(1, Ordering::SeqCst);
+        spin_until(&co1, 2); // Wait for thread 2 to call foo()
+        shared_fake1.bar();
+        co1.swap(3, Ordering::SeqCst);
+        dc1.fetch_add(1, Ordering::Relaxed);
+    });
+
+    let dc2 = done_count.clone();
+    let co2 = call_order.clone();
+    let _t2 = std::thread::spawn(move || {
+        spin_until(&co2, 1); // Wait for thread 1 to grab L
+        shared_fake2.foo();
+        dc2.fetch_add(1, Ordering::Relaxed);
+    });
+
+    std::thread::sleep(Duration::from_millis(100)); // FIXME maybe we can do better?
+    assert_eq!(done_count.load(Ordering::Relaxed), 2);
+}
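One possible answer to the `FIXME` comments above, sketched here only and not part of this diff: keep the `JoinHandle`s returned by `std::thread::spawn` and `join()` them instead of sleeping for a fixed 100ms. The trade-off is that a reintroduced deadlock would make the test hang until the harness times out rather than fail the assertion quickly. The `t1`/`t2` bindings below are illustrative; `shared_fake1` and `shared_fake2` are the `Arc` clones from `mock_multi_threaded_access`.

```rust
// Sketch: join the worker threads instead of sleeping; done_count becomes unnecessary.
let t1 = std::thread::spawn(move || {
    for _ in 0..10000 {
        shared_fake1.bar();
    }
});
let t2 = std::thread::spawn(move || {
    for _ in 0..10000 {
        shared_fake2.bar();
    }
});

// join() blocks until each thread finishes and surfaces a panic inside the thread as an Err.
t1.join().expect("thread 1 panicked");
t2.join().expect("thread 2 panicked");
```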