Skip to content

Commit

Permalink
fix: hold runtime ref and handle to prevent spawn after shutdown (#736)
Browse files Browse the repository at this point in the history
* fix: hold runtime ref and handle to prevent spawn after shutdown

Signed-off-by: MrCroxx <[email protected]>

* fix: make ffmt happy

Signed-off-by: MrCroxx <[email protected]>

* fix: make device hold runtime, too

Signed-off-by: MrCroxx <[email protected]>

---------

Signed-off-by: MrCroxx <[email protected]>
  • Loading branch information
MrCroxx authored Sep 24, 2024
1 parent 5fe3957 commit b49e275
Show file tree
Hide file tree
Showing 20 changed files with 330 additions and 178 deletions.
8 changes: 4 additions & 4 deletions foyer-common/src/asyncify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use tokio::runtime::Handle;
use crate::runtime::SingletonHandle;

/// Convert the block call to async call.
#[cfg(not(madsim))]
Expand All @@ -36,9 +36,9 @@ where
f()
}

/// Convert the block call to async call with given runtime.
/// Convert the block call to async call with given runtime handle.
#[cfg(not(madsim))]
pub async fn asyncify_with_runtime<F, T>(runtime: &Handle, f: F) -> T
pub async fn asyncify_with_runtime<F, T>(runtime: &SingletonHandle, f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
Expand All @@ -50,7 +50,7 @@ where
/// Convert the block call to async call with given runtime.
///
/// madsim compatible mode.
pub async fn asyncify_with_runtime<F, T>(_: &Handle, f: F) -> T
pub async fn asyncify_with_runtime<F, T>(_: &SingletonHandle, f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
Expand Down
84 changes: 83 additions & 1 deletion foyer-common/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

use std::{
fmt::Debug,
future::Future,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
};

use tokio::runtime::Runtime;
use tokio::{
runtime::{Handle, Runtime},
task::JoinHandle,
};

/// A wrapper around [`Runtime`] that shuts down the runtime in the background when dropped.
///
Expand Down Expand Up @@ -62,3 +66,81 @@ impl From<Runtime> for BackgroundShutdownRuntime {
Self(ManuallyDrop::new(runtime))
}
}

/// A non-clonable runtime handle.
#[derive(Debug)]
pub struct SingletonHandle(Handle);

impl From<Handle> for SingletonHandle {
fn from(handle: Handle) -> Self {
Self(handle)
}
}

impl SingletonHandle {
/// Spawns a future onto the Tokio runtime.
///
/// This spawns the given future onto the runtime's executor, usually a
/// thread pool. The thread pool is then responsible for polling the future
/// until it completes.
///
/// The provided future will start running in the background immediately
/// when `spawn` is called, even if you don't await the returned
/// `JoinHandle`.
///
/// See [module level][mod] documentation for more details.
///
/// [mod]: index.html
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// # fn dox() {
/// // Create the runtime
/// let rt = Runtime::new().unwrap();
/// // Get a handle from this runtime
/// let handle = rt.handle();
///
/// // Spawn a future onto the runtime using the handle
/// handle.spawn(async {
/// println!("now running on a worker thread");
/// });
/// # }
/// ```
pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.0.spawn(future)
}

/// Runs the provided function on an executor dedicated to blocking
/// operations.
///
/// # Examples
///
/// ```
/// use tokio::runtime::Runtime;
///
/// # fn dox() {
/// // Create the runtime
/// let rt = Runtime::new().unwrap();
/// // Get a handle from this runtime
/// let handle = rt.handle();
///
/// // Spawn a blocking function onto the runtime using the handle
/// handle.spawn_blocking(|| {
/// println!("now running on a worker thread");
/// });
/// # }
pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
self.0.spawn_blocking(func)
}
}
3 changes: 2 additions & 1 deletion foyer-memory/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use foyer_common::{
code::{HashBuilder, Key, Value},
event::EventListener,
future::Diversion,
runtime::SingletonHandle,
};
use futures::Future;
use pin_project::pin_project;
Expand Down Expand Up @@ -834,7 +835,7 @@ where
key: K,
context: CacheContext,
fetch: F,
runtime: &tokio::runtime::Handle,
runtime: &SingletonHandle,
) -> Fetch<K, V, ER, S>
where
F: FnOnce() -> FU,
Expand Down
13 changes: 10 additions & 3 deletions foyer-memory/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use foyer_common::{
future::{Diversion, DiversionFuture},
metrics::Metrics,
object_pool::ObjectPool,
runtime::SingletonHandle,
strict_assert, strict_assert_eq,
};
use hashbrown::hash_map::{Entry as HashMapEntry, HashMap};
Expand Down Expand Up @@ -739,7 +740,12 @@ where
FU: Future<Output = std::result::Result<V, ER>> + Send + 'static,
ER: Send + 'static + Debug,
{
self.fetch_inner(key, CacheContext::default(), fetch, &tokio::runtime::Handle::current())
self.fetch_inner(
key,
CacheContext::default(),
fetch,
&tokio::runtime::Handle::current().into(),
)
}

pub fn fetch_with_context<F, FU, ER>(
Expand All @@ -753,15 +759,16 @@ where
FU: Future<Output = std::result::Result<V, ER>> + Send + 'static,
ER: Send + 'static + Debug,
{
self.fetch_inner(key, context, fetch, &tokio::runtime::Handle::current())
self.fetch_inner(key, context, fetch, &tokio::runtime::Handle::current().into())
}

#[doc(hidden)]
pub fn fetch_inner<F, FU, ER, ID>(
self: &Arc<Self>,
key: K,
context: CacheContext,
fetch: F,
runtime: &tokio::runtime::Handle,
runtime: &SingletonHandle,
) -> GenericFetch<K, V, E, I, S, ER>
where
F: FnOnce() -> FU,
Expand Down
20 changes: 9 additions & 11 deletions foyer-storage/src/device/direct_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ use std::{
use foyer_common::{asyncify::asyncify_with_runtime, bits};
use fs4::free_space;
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;

use super::{Dev, DevExt, DevOptions, RegionId};
use crate::{
device::ALIGN,
error::{Error, Result},
IoBytes, IoBytesMut,
IoBytes, IoBytesMut, Runtime,
};

/// Options for the direct file device.
Expand All @@ -49,7 +48,7 @@ pub struct DirectFileDevice {
capacity: usize,
region_size: usize,

runtime: Handle,
runtime: Runtime,
}

impl DevOptions for DirectFileDeviceOptions {
Expand Down Expand Up @@ -90,7 +89,7 @@ impl DirectFileDevice {

let file = self.file.clone();

asyncify_with_runtime(&self.runtime, move || {
asyncify_with_runtime(self.runtime.write(), move || {
#[cfg(target_family = "windows")]
let written = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -133,7 +132,7 @@ impl DirectFileDevice {

let file = self.file.clone();

let mut buffer = asyncify_with_runtime(&self.runtime, move || {
let mut buffer = asyncify_with_runtime(self.runtime.read(), move || {
#[cfg(target_family = "windows")]
let read = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -172,9 +171,7 @@ impl Dev for DirectFileDevice {
}

#[fastrace::trace(name = "foyer::storage::device::direct_file::open")]
async fn open(options: Self::Options) -> Result<Self> {
let runtime = Handle::current();

async fn open(options: Self::Options, runtime: Runtime) -> Result<Self> {
options.verify()?;

let dir = options
Expand Down Expand Up @@ -253,7 +250,7 @@ impl Dev for DirectFileDevice {
#[fastrace::trace(name = "foyer::storage::device::direct_file::flush")]
async fn flush(&self, _: Option<RegionId>) -> Result<()> {
let file = self.file.clone();
asyncify_with_runtime(&self.runtime, move || file.sync_all().map_err(Error::from)).await
asyncify_with_runtime(self.runtime.write(), move || file.sync_all().map_err(Error::from)).await
}
}

Expand Down Expand Up @@ -360,6 +357,7 @@ mod tests {
#[test_log::test(tokio::test)]
async fn test_direct_file_device_io() {
let dir = tempfile::tempdir().unwrap();
let runtime = Runtime::current();

let options = DirectFileDeviceOptionsBuilder::new(dir.path().join("test-direct-file"))
.with_capacity(4 * 1024 * 1024)
Expand All @@ -368,7 +366,7 @@ mod tests {

tracing::debug!("{options:?}");

let device = DirectFileDevice::open(options.clone()).await.unwrap();
let device = DirectFileDevice::open(options.clone(), runtime.clone()).await.unwrap();

let mut buf = IoBytesMut::with_capacity(64 * 1024);
buf.extend(repeat_n(b'x', 64 * 1024 - 100));
Expand All @@ -383,7 +381,7 @@ mod tests {

drop(device);

let device = DirectFileDevice::open(options).await.unwrap();
let device = DirectFileDevice::open(options, runtime).await.unwrap();

let b = device.read(0, 4096, 64 * 1024 - 100).await.unwrap().freeze();
assert_eq!(buf, b);
Expand Down
25 changes: 12 additions & 13 deletions foyer-storage/src/device/direct_fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ use fs4::free_space;
use futures::future::try_join_all;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::runtime::Handle;

use super::{Dev, DevExt, DevOptions, RegionId};
use crate::{
device::ALIGN,
error::{Error, Result},
IoBytes, IoBytesMut,
IoBytes, IoBytesMut, Runtime,
};

/// Options for the direct fs device.
Expand All @@ -56,7 +55,7 @@ struct DirectFsDeviceInner {
capacity: usize,
file_size: usize,

runtime: Handle,
runtime: Runtime,
}

impl DevOptions for DirectFsDeviceOptions {
Expand Down Expand Up @@ -106,17 +105,16 @@ impl Dev for DirectFsDevice {
}

#[fastrace::trace(name = "foyer::storage::device::direct_fs::open")]
async fn open(options: Self::Options) -> Result<Self> {
let runtime = Handle::current();

async fn open(options: Self::Options, runtime: Runtime) -> Result<Self> {
options.verify()?;

// TODO(MrCroxx): write and read options to a manifest file for pinning

let regions = options.capacity / options.file_size;

let path = options.dir.clone();
asyncify_with_runtime(&runtime, move || create_dir_all(path)).await?;
if !options.dir.exists() {
create_dir_all(&options.dir)?;
}

let futures = (0..regions)
.map(|i| {
Expand Down Expand Up @@ -165,7 +163,7 @@ impl Dev for DirectFsDevice {

let file = self.file(region).clone();

asyncify_with_runtime(&self.inner.runtime, move || {
asyncify_with_runtime(self.inner.runtime.write(), move || {
#[cfg(target_family = "windows")]
let written = {
use std::os::windows::fs::FileExt;
Expand Down Expand Up @@ -207,7 +205,7 @@ impl Dev for DirectFsDevice {

let file = self.file(region).clone();

let mut buffer = asyncify_with_runtime(&self.inner.runtime, move || {
let mut buffer = asyncify_with_runtime(self.inner.runtime.read(), move || {
#[cfg(target_family = "unix")]
let read = {
use std::os::unix::fs::FileExt;
Expand Down Expand Up @@ -237,7 +235,7 @@ impl Dev for DirectFsDevice {
async fn flush(&self, region: Option<super::RegionId>) -> Result<()> {
let flush = |region: RegionId| {
let file = self.file(region).clone();
asyncify_with_runtime(&self.inner.runtime, move || file.sync_all().map_err(Error::from))
asyncify_with_runtime(self.inner.runtime.write(), move || file.sync_all().map_err(Error::from))
};

if let Some(region) = region {
Expand Down Expand Up @@ -352,6 +350,7 @@ mod tests {
#[test_log::test(tokio::test)]
async fn test_direct_fd_device_io() {
let dir = tempfile::tempdir().unwrap();
let runtime = Runtime::current();

let options = DirectFsDeviceOptionsBuilder::new(dir.path())
.with_capacity(4 * 1024 * 1024)
Expand All @@ -360,7 +359,7 @@ mod tests {

tracing::debug!("{options:?}");

let device = DirectFsDevice::open(options.clone()).await.unwrap();
let device = DirectFsDevice::open(options.clone(), runtime.clone()).await.unwrap();

let mut buf = IoBytesMut::with_capacity(64 * 1024);
buf.extend(repeat_n(b'x', 64 * 1024 - 100));
Expand All @@ -375,7 +374,7 @@ mod tests {

drop(device);

let device = DirectFsDevice::open(options).await.unwrap();
let device = DirectFsDevice::open(options, runtime).await.unwrap();

let b = device.read(0, 4096, 64 * 1024 - 100).await.unwrap().freeze();
assert_eq!(buf, b);
Expand Down
Loading

0 comments on commit b49e275

Please sign in to comment.