diff --git a/README.md b/README.md index e2bb7ef..ab00c4a 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ function-cache wrapped in a mutex/rwlock, or externally synchronized in the case By default, the function-cache is **not** locked for the duration of the function's execution, so initial (on an empty cache) concurrent calls of long-running functions with the same arguments will each execute fully and each overwrite the memoized value as they complete. This mirrors the behavior of Python's `functools.lru_cache`. To synchronize the execution and caching -of un-cached arguments, specify `#[cached(sync_writes = true)]` / `#[once(sync_writes = true)]` (not supported by `#[io_cached]`. +of un-cached arguments, specify `#[cached(sync_writes = "default")]` / `#[once(sync_writes = "default")]` (not supported by `#[io_cached]`). - See [`cached::stores` docs](https://docs.rs/cached/latest/cached/stores/index.html) cache stores available. - See [`proc_macro`](https://docs.rs/cached/latest/cached/proc_macro/index.html) for more procedural macro examples. @@ -93,7 +93,7 @@ use cached::proc_macro::once; /// When no (or expired) cache, concurrent calls /// will synchronize (`sync_writes`) so the function /// is only executed once. 
-#[once(time=10, option = true, sync_writes = true)] +#[once(time=10, option = true, sync_writes = "default")] fn keyed(a: String) -> Option<usize> { if a == "a" { Some(a.len()) } else { None } @@ -112,7 +112,7 @@ use cached::proc_macro::cached; #[cached( result = true, time = 1, - sync_writes = true, + sync_writes = "default", result_fallback = true )] fn doesnt_compile() -> Result<String, ()> { diff --git a/cached_proc_macro/src/cached.rs b/cached_proc_macro/src/cached.rs index fc63555..4eec8b8 100644 --- a/cached_proc_macro/src/cached.rs +++ b/cached_proc_macro/src/cached.rs @@ -3,9 +3,17 @@ use darling::ast::NestedMeta; use darling::FromMeta; use proc_macro::TokenStream; use quote::quote; +use std::cmp::PartialEq; use syn::spanned::Spanned; use syn::{parse_macro_input, parse_str, Block, Ident, ItemFn, ReturnType, Type}; +#[derive(Debug, Default, FromMeta, Eq, PartialEq)] +enum SyncWriteMode { + #[default] + Default, + ByKey, +} + #[derive(FromMeta)] struct MacroArgs { #[darling(default)] @@ -27,7 +35,7 @@ struct MacroArgs { #[darling(default)] option: bool, #[darling(default)] - sync_writes: bool, + sync_writes: Option<SyncWriteMode>, #[darling(default)] with_cached_flag: bool, #[darling(default)] @@ -190,8 +198,8 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { _ => panic!("the result and option attributes are mutually exclusive"), }; - if args.result_fallback && args.sync_writes { - panic!("the result_fallback and sync_writes attributes are mutually exclusive"); + if args.result_fallback && args.sync_writes.is_some() { + panic!("result_fallback and sync_writes are mutually exclusive"); } let set_cache_and_return = quote! { @@ -206,8 +214,19 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let function_call; let ty; if asyncness.is_some() { - lock = quote! { - let mut cache = #cache_ident.lock().await; + lock = match args.sync_writes { + Some(SyncWriteMode::ByKey) => quote! 
{ + let mut locks = #cache_ident.lock().await; + let lock = locks + .entry(key.clone()) + .or_insert_with(|| std::sync::Arc::new(::cached::async_sync::Mutex::new(#cache_create))) + .clone(); + drop(locks); + let mut cache = lock.lock().await; + }, + _ => quote! { + let mut cache = #cache_ident.lock().await; + }, }; function_no_cache = quote! { @@ -218,12 +237,25 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*).await; }; - ty = quote! { - #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create)); + ty = match args.sync_writes { + Some(SyncWriteMode::ByKey) => quote! { + #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<::cached::async_sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(std::collections::HashMap::new())); + }, + _ => quote! { + #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<::cached::async_sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| ::cached::async_sync::Mutex::new(#cache_create)); + }, }; } else { - lock = quote! { - let mut cache = #cache_ident.lock().unwrap(); + lock = match args.sync_writes { + Some(SyncWriteMode::ByKey) => quote! { + let mut locks = #cache_ident.lock().unwrap(); + let lock = locks.entry(key.clone()).or_insert_with(|| std::sync::Arc::new(std::sync::Mutex::new(#cache_create))).clone(); + drop(locks); + let mut cache = lock.lock().unwrap(); + }, + _ => quote! { + let mut cache = #cache_ident.lock().unwrap(); + }, }; function_no_cache = quote! { @@ -234,9 +266,14 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { let result = #no_cache_fn_ident(#(#input_names),*); }; - ty = quote! 
{ - #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create)); - }; + ty = match args.sync_writes { + Some(SyncWriteMode::ByKey) => quote! { + #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<std::collections::HashMap<#cache_key_ty, std::sync::Arc<std::sync::Mutex<#cache_ty>>>>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(std::collections::HashMap::new())); + }, + _ => quote! { + #visibility static #cache_ident: ::cached::once_cell::sync::Lazy<std::sync::Mutex<#cache_ty>> = ::cached::once_cell::sync::Lazy::new(|| std::sync::Mutex::new(#cache_create)); + }, + } } let prime_do_set_return_block = quote! { @@ -247,7 +284,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream { #set_cache_and_return }; - let do_set_return_block = if args.sync_writes { + let do_set_return_block = if args.sync_writes.is_some() { quote! { #lock if let Some(result) = cache.cache_get(&key) { diff --git a/cached_proc_macro/src/lib.rs b/cached_proc_macro/src/lib.rs index 83d6d39..5985017 100644 --- a/cached_proc_macro/src/lib.rs +++ b/cached_proc_macro/src/lib.rs @@ -14,6 +14,7 @@ use proc_macro::TokenStream; /// - `time`: (optional, u64) specify a cache TTL in seconds, implies the cache type is a `TimedCache` or `TimedSizedCache`. /// - `time_refresh`: (optional, bool) specify whether to refresh the TTL on cache hits. -/// - `sync_writes`: (optional, bool) specify whether to synchronize the execution of writing of uncached values. +/// - `sync_writes`: (optional, string) specify whether to synchronize the execution of writing of uncached values. Either `"default"` or `"by_key"`. +/// `"by_key"` synchronizes writes per cache key, so calls with different keys can run concurrently. /// - `ty`: (optional, string type) The cache store type to use. Defaults to `UnboundCache`. When `unbound` is /// specified, defaults to `UnboundCache`. When `size` is specified, defaults to `SizedCache`. /// When `time` is specified, defaults to `TimedCached`. 
diff --git a/cached_proc_macro/src/once.rs b/cached_proc_macro/src/once.rs index 70d5617..907f74b 100644 --- a/cached_proc_macro/src/once.rs +++ b/cached_proc_macro/src/once.rs @@ -6,6 +6,12 @@ use quote::quote; use syn::spanned::Spanned; use syn::{parse_macro_input, Ident, ItemFn, ReturnType}; +#[derive(Debug, Default, FromMeta)] +enum SyncWriteMode { + #[default] + Default, +} + #[derive(FromMeta)] struct OnceMacroArgs { #[darling(default)] @@ -13,7 +19,7 @@ struct OnceMacroArgs { #[darling(default)] time: Option<u64>, #[darling(default)] - sync_writes: bool, + sync_writes: Option<SyncWriteMode>, #[darling(default)] result: bool, #[darling(default)] @@ -220,8 +226,8 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } }; - let do_set_return_block = if args.sync_writes { - quote! { + let do_set_return_block = match args.sync_writes { + Some(SyncWriteMode::Default) => quote! { #r_lock_return_cache_block #w_lock if let Some(result) = &*cached { @@ -229,14 +235,13 @@ pub fn once(args: TokenStream, input: TokenStream) -> TokenStream { } #function_call #set_cache_and_return - } - } else { - quote! { + }, + None => quote! { #r_lock_return_cache_block #function_call #w_lock #set_cache_and_return - } + }, }; let signature_no_muts = get_mut_signature(signature); diff --git a/examples/async_std.rs b/examples/async_std.rs index a619993..53740e8 100644 --- a/examples/async_std.rs +++ b/examples/async_std.rs @@ -86,7 +86,7 @@ async fn only_cached_once_per_second(s: String) -> Vec<String> { /// _one_ call will be "executed" and all others will be synchronized /// to return the cached result of the one call instead of all /// concurrently un-cached tasks executing and writing concurrently. 
-#[once(time = 2, sync_writes = true)] +#[once(time = 2, sync_writes = "default")] async fn only_cached_once_per_second_sync_writes(s: String) -> Vec<String> { vec![s] } diff --git a/src/lib.rs b/src/lib.rs index 163f322..67d8f9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ function-cache wrapped in a mutex/rwlock, or externally synchronized in the case By default, the function-cache is **not** locked for the duration of the function's execution, so initial (on an empty cache) concurrent calls of long-running functions with the same arguments will each execute fully and each overwrite the memoized value as they complete. This mirrors the behavior of Python's `functools.lru_cache`. To synchronize the execution and caching -of un-cached arguments, specify `#[cached(sync_writes = true)]` / `#[once(sync_writes = true)]` (not supported by `#[io_cached]`. +of un-cached arguments, specify `#[cached(sync_writes = "default")]` / `#[once(sync_writes = "default")]` (not supported by `#[io_cached]`). - See [`cached::stores` docs](https://docs.rs/cached/latest/cached/stores/index.html) cache stores available. - See [`proc_macro`](https://docs.rs/cached/latest/cached/proc_macro/index.html) for more procedural macro examples. @@ -94,7 +94,7 @@ use cached::proc_macro::once; /// When no (or expired) cache, concurrent calls /// will synchronize (`sync_writes`) so the function /// is only executed once. 
-#[once(time=10, option = true, sync_writes = true)] +#[once(time=10, option = true, sync_writes = "default")] fn keyed(a: String) -> Option<usize> { if a == "a" { Some(a.len()) @@ -114,7 +114,7 @@ use cached::proc_macro::cached; #[cached( result = true, time = 1, - sync_writes = true, + sync_writes = "default", result_fallback = true )] fn doesnt_compile() -> Result<String, ()> { diff --git a/src/proc_macro.rs b/src/proc_macro.rs index b61a702..9453d8d 100644 --- a/src/proc_macro.rs +++ b/src/proc_macro.rs @@ -115,7 +115,7 @@ use cached::proc_macro::cached; /// When called concurrently, duplicate argument-calls will be /// synchronized so as to only run once - the remaining concurrent /// calls return a cached value. -#[cached(size=1, option = true, sync_writes = true)] +#[cached(size=1, option = true, sync_writes = "default")] fn keyed(a: String) -> Option<usize> { if a == "a" { Some(a.len()) } else { None } @@ -233,7 +233,7 @@ use cached::proc_macro::once; /// When no (or expired) cache, concurrent calls /// will synchronize (`sync_writes`) so the function /// is only executed once. -#[once(time=10, option = true, sync_writes = true)] +#[once(time=10, option = true, sync_writes = "default")] fn keyed(a: String) -> Option<usize> { if a == "a" { Some(a.len()) diff --git a/tests/cached.rs b/tests/cached.rs index 8a161bd..0d33e04 100644 --- a/tests/cached.rs +++ b/tests/cached.rs @@ -10,7 +10,7 @@ use cached::{ }; use serial_test::serial; use std::thread::{self, sleep}; -use std::time::Duration; +use std::time::{Duration, Instant}; cached! { UNBOUND_FIB; @@ -848,7 +848,7 @@ async fn test_only_cached_option_once_per_second_a() { /// to return the cached result of the one call instead of all /// concurrently un-cached tasks executing and writing concurrently. 
#[cfg(feature = "async")] -#[once(time = 2, sync_writes = true)] +#[once(time = 2, sync_writes = "default")] async fn only_cached_once_per_second_sync_writes(s: String) -> Vec<String> { vec![s] } @@ -862,7 +862,7 @@ async fn test_only_cached_once_per_second_sync_writes() { assert_eq!(a.await.unwrap(), b.await.unwrap()); } -#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")] +#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")] fn cached_sync_writes(s: String) -> Vec<String> { vec![s] } @@ -881,7 +881,7 @@ fn test_cached_sync_writes() { } #[cfg(feature = "async")] -#[cached(time = 2, sync_writes = true, key = "u32", convert = "{ 1 }")] +#[cached(time = 2, sync_writes = "default", key = "u32", convert = "{ 1 }")] async fn cached_sync_writes_a(s: String) -> Vec<String> { vec![s] } @@ -898,8 +898,51 @@ async fn test_cached_sync_writes_a() { assert_eq!(a, c.await.unwrap()); } +#[cached(time = 2, sync_writes = "by_key", key = "u32", convert = "{ 1 }")] +fn cached_sync_writes_by_key(s: String) -> Vec<String> { + sleep(Duration::new(1, 0)); + vec![s] +} + +#[test] +fn test_cached_sync_writes_by_key() { + let a = std::thread::spawn(|| cached_sync_writes_by_key("a".to_string())); + let b = std::thread::spawn(|| cached_sync_writes_by_key("b".to_string())); + let c = std::thread::spawn(|| cached_sync_writes_by_key("c".to_string())); + let start = Instant::now(); + let a = a.join().unwrap(); + let b = b.join().unwrap(); + let c = c.join().unwrap(); + assert!(start.elapsed() < Duration::from_secs(2)); +} + +#[cfg(feature = "async")] +#[cached( + time = 5, + sync_writes = "by_key", + key = "String", + convert = r#"{ format!("{}", s) }"# +)] +async fn cached_sync_writes_by_key_a(s: String) -> Vec<String> { + tokio::time::sleep(Duration::from_secs(1)).await; + vec![s] +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_cached_sync_writes_by_key_a() { + let a = tokio::spawn(cached_sync_writes_by_key_a("a".to_string())); + let b = 
tokio::spawn(cached_sync_writes_by_key_a("b".to_string())); + let c = tokio::spawn(cached_sync_writes_by_key_a("c".to_string())); + let start = Instant::now(); + a.await.unwrap(); + b.await.unwrap(); + c.await.unwrap(); + assert!(start.elapsed() < Duration::from_secs(2)); +} + #[cfg(feature = "async")] -#[once(sync_writes = true)] +#[once(sync_writes = "default")] async fn once_sync_writes_a(s: &tokio::sync::Mutex<String>) -> String { let mut guard = s.lock().await; let results: String = (*guard).clone().to_string();