|
2 | 2 | #![warn(rust_2018_idioms)]
|
3 | 3 | #![cfg(all(feature = "full", not(target_os = "wasi"), target_has_atomic = "64"))]
|
4 | 4 |
|
5 |
| -use std::sync::{Arc, Barrier}; |
| 5 | +use std::sync::mpsc; |
| 6 | +use std::time::Duration; |
6 | 7 | use tokio::runtime::Runtime;
|
7 | 8 |
|
8 | 9 | #[test]
|
@@ -68,36 +69,51 @@ fn global_queue_depth_multi_thread() {
|
68 | 69 | let rt = threaded();
|
69 | 70 | let metrics = rt.metrics();
|
70 | 71 |
|
71 |
| - let barrier1 = Arc::new(Barrier::new(3)); |
72 |
| - let barrier2 = Arc::new(Barrier::new(3)); |
73 |
| - |
74 |
| - // Spawn a task per runtime worker to block it. |
75 |
| - for _ in 0..2 { |
76 |
| - let barrier1 = barrier1.clone(); |
77 |
| - let barrier2 = barrier2.clone(); |
78 |
| - rt.spawn(async move { |
79 |
| - barrier1.wait(); |
80 |
| - barrier2.wait(); |
81 |
| - }); |
82 |
| - } |
83 |
| - |
84 |
| - barrier1.wait(); |
| 72 | + for _ in 0..10 { |
| 73 | + if let Ok(_blocking_tasks) = try_block_threaded(&rt) { |
| 74 | + for i in 0..10 { |
| 75 | + assert_eq!(i, metrics.global_queue_depth()); |
| 76 | + rt.spawn(async {}); |
| 77 | + } |
85 | 78 |
|
86 |
| - let mut fail: Option<String> = None; |
87 |
| - for i in 0..10 { |
88 |
| - let depth = metrics.global_queue_depth(); |
89 |
| - if i != depth { |
90 |
| - fail = Some(format!("{i} is not equal to {depth}")); |
91 |
| - break; |
| 79 | + return; |
92 | 80 | }
|
93 |
| - rt.spawn(async {}); |
94 | 81 | }
|
95 | 82 |
|
96 |
| - barrier2.wait(); |
| 83 | + panic!("exhausted every try to block the runtime"); |
| 84 | +} |
97 | 85 |
|
98 |
| - if let Some(fail) = fail { |
99 |
| - panic!("{fail}"); |
| 86 | +fn try_block_threaded(rt: &Runtime) -> Result<Vec<mpsc::Sender<()>>, mpsc::RecvTimeoutError> { |
| 87 | + let (tx, rx) = mpsc::channel(); |
| 88 | + |
| 89 | + let blocking_tasks = (0..rt.metrics().num_workers()) |
| 90 | + .map(|_| { |
| 91 | + let tx = tx.clone(); |
| 92 | + let (task, barrier) = mpsc::channel(); |
| 93 | + |
| 94 | + // Spawn a task per runtime worker to block it. |
| 95 | + rt.spawn(async move { |
| 96 | + tx.send(()).unwrap(); |
| 97 | + barrier.recv().ok(); |
| 98 | + }); |
| 99 | + |
| 100 | + task |
| 101 | + }) |
| 102 | + .collect(); |
| 103 | + |
| 104 | + // Make sure the previously spawned tasks are blocking the runtime by |
| 105 | + // receiving a message from each blocking task. |
| 106 | + // |
| 107 | + // If this times out we were unsuccessful in blocking the runtime and hit |
| 108 | + // a deadlock instead (which might happen and is expected behaviour). |
| 109 | + for _ in 0..rt.metrics().num_workers() { |
| 110 | + rx.recv_timeout(Duration::from_secs(1))?; |
100 | 111 | }
|
| 112 | + |
| 113 | + // Return senders of the mpsc channels used for blocking the runtime as a |
| 114 | + // surrogate handle for the tasks. Sending a message or dropping the senders |
| 115 | + // will unblock the runtime. |
| 116 | + Ok(blocking_tasks) |
101 | 117 | }
|
102 | 118 |
|
103 | 119 | fn current_thread() -> Runtime {
|
|
0 commit comments