Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove OnHit callback from query caches. #107667

Merged
merged 2 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 20 additions & 40 deletions compiler/rustc_middle/src/ty/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,6 @@ impl<'tcx> TyCtxt<'tcx> {
}
}

/// Helper for `TyCtxtEnsure` to avoid a closure.
#[inline(always)]
fn noop<T>(_: &T) {}

/// Helper to ensure that queries only return `Copy` types.
#[inline(always)]
fn copy<T: Copy>(x: &T) -> T {
*x
}

macro_rules! query_helper_param_ty {
(DefId) => { impl IntoQueryParam<DefId> };
(LocalDefId) => { impl IntoQueryParam<LocalDefId> };
Expand Down Expand Up @@ -225,14 +215,10 @@ macro_rules! define_callbacks {
let key = key.into_query_param();
opt_remap_env_constness!([$($modifiers)*][key]);

let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, noop);

match cached {
Ok(()) => return,
Err(()) => (),
}

self.tcx.queries.$name(self.tcx, DUMMY_SP, key, QueryMode::Ensure);
match try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key) {
Some(_) => return,
None => self.tcx.queries.$name(self.tcx, DUMMY_SP, key, QueryMode::Ensure),
};
})*
}

Expand All @@ -254,14 +240,10 @@ macro_rules! define_callbacks {
let key = key.into_query_param();
opt_remap_env_constness!([$($modifiers)*][key]);

let cached = try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key, copy);

match cached {
Ok(value) => return value,
Err(()) => (),
match try_get_cached(self.tcx, &self.tcx.query_caches.$name, &key) {
Some(value) => value,
None => self.tcx.queries.$name(self.tcx, self.span, key, QueryMode::Get).unwrap(),
}

self.tcx.queries.$name(self.tcx, self.span, key, QueryMode::Get).unwrap()
})*
}

Expand Down Expand Up @@ -353,27 +335,25 @@ macro_rules! define_feedable {
let tcx = self.tcx;
let cache = &tcx.query_caches.$name;

let cached = try_get_cached(tcx, cache, &key, copy);

match cached {
Ok(old) => {
match try_get_cached(tcx, cache, &key) {
Some(old) => {
bug!(
"Trying to feed an already recorded value for query {} key={key:?}:\nold value: {old:?}\nnew value: {value:?}",
stringify!($name),
)
}
None => {
let dep_node = dep_graph::DepNode::construct(tcx, dep_graph::DepKind::$name, &key);
let dep_node_index = tcx.dep_graph.with_feed_task(
dep_node,
tcx,
key,
&value,
hash_result!([$($modifiers)*]),
);
cache.complete(key, value, dep_node_index)
}
Err(()) => (),
}

let dep_node = dep_graph::DepNode::construct(tcx, dep_graph::DepKind::$name, &key);
let dep_node_index = tcx.dep_graph.with_feed_task(
dep_node,
tcx,
key,
&value,
hash_result!([$($modifiers)*]),
);
cache.complete(key, value, dep_node_index)
}
})*
}
Expand Down
73 changes: 17 additions & 56 deletions compiler/rustc_query_system/src/query/caches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ use std::marker::PhantomData;
pub trait CacheSelector<'tcx, V> {
type Cache
where
V: Clone;
V: Copy;
type ArenaCache;
}

pub trait QueryStorage {
type Value: Debug;
type Stored: Clone;
type Stored: Copy;

/// Store a value without putting it in the cache.
/// This is meant to be used with cycle errors.
Expand All @@ -36,14 +36,7 @@ pub trait QueryCache: QueryStorage + Sized {
/// It returns the shard index and a lock guard to the shard,
/// which will be used if the query is not in the cache and we need
/// to compute it.
fn lookup<R, OnHit>(
&self,
key: &Self::Key,
// `on_hit` can be called while holding a lock to the query state shard.
on_hit: OnHit,
) -> Result<R, ()>
where
OnHit: FnOnce(&Self::Stored, DepNodeIndex) -> R;
fn lookup(&self, key: &Self::Key) -> Option<(Self::Stored, DepNodeIndex)>;

fn complete(&self, key: Self::Key, value: Self::Value, index: DepNodeIndex) -> Self::Stored;

Expand All @@ -55,7 +48,7 @@ pub struct DefaultCacheSelector<K>(PhantomData<K>);
impl<'tcx, K: Eq + Hash, V: 'tcx> CacheSelector<'tcx, V> for DefaultCacheSelector<K> {
type Cache = DefaultCache<K, V>
where
V: Clone;
V: Copy;
type ArenaCache = ArenaCache<'tcx, K, V>;
}

Expand All @@ -72,7 +65,7 @@ impl<K, V> Default for DefaultCache<K, V> {
}
}

impl<K: Eq + Hash, V: Clone + Debug> QueryStorage for DefaultCache<K, V> {
impl<K: Eq + Hash, V: Copy + Debug> QueryStorage for DefaultCache<K, V> {
type Value = V;
type Stored = V;

Expand All @@ -86,28 +79,20 @@ impl<K: Eq + Hash, V: Clone + Debug> QueryStorage for DefaultCache<K, V> {
impl<K, V> QueryCache for DefaultCache<K, V>
where
K: Eq + Hash + Clone + Debug,
V: Clone + Debug,
V: Copy + Debug,
{
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&V, DepNodeIndex) -> R,
{
fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
let key_hash = sharded::make_hash(key);
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key_hash).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);

if let Some((_, value)) = result {
let hit_result = on_hit(&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
if let Some((_, value)) = result { Some(*value) } else { None }
}

#[inline]
Expand Down Expand Up @@ -176,23 +161,15 @@ where
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
{
fn lookup(&self, key: &K) -> Option<(&'tcx V, DepNodeIndex)> {
let key_hash = sharded::make_hash(key);
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key_hash).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
let result = lock.raw_entry().from_key_hashed_nocheck(key_hash, key);

if let Some((_, value)) = result {
let hit_result = on_hit(&&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
if let Some((_, value)) = result { Some((&value.0, value.1)) } else { None }
}

#[inline]
Expand Down Expand Up @@ -234,7 +211,7 @@ pub struct VecCacheSelector<K>(PhantomData<K>);
impl<'tcx, K: Idx, V: 'tcx> CacheSelector<'tcx, V> for VecCacheSelector<K> {
type Cache = VecCache<K, V>
where
V: Clone;
V: Copy;
type ArenaCache = VecArenaCache<'tcx, K, V>;
}

Expand All @@ -251,7 +228,7 @@ impl<K: Idx, V> Default for VecCache<K, V> {
}
}

impl<K: Eq + Idx, V: Clone + Debug> QueryStorage for VecCache<K, V> {
impl<K: Eq + Idx, V: Copy + Debug> QueryStorage for VecCache<K, V> {
type Value = V;
type Stored = V;

Expand All @@ -265,25 +242,17 @@ impl<K: Eq + Idx, V: Clone + Debug> QueryStorage for VecCache<K, V> {
impl<K, V> QueryCache for VecCache<K, V>
where
K: Eq + Idx + Clone + Debug,
V: Clone + Debug,
V: Copy + Debug,
{
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&V, DepNodeIndex) -> R,
{
fn lookup(&self, key: &K) -> Option<(V, DepNodeIndex)> {
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
if let Some(Some(value)) = lock.get(*key) {
let hit_result = on_hit(&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
if let Some(Some(value)) = lock.get(*key) { Some(*value) } else { None }
}

#[inline]
Expand Down Expand Up @@ -357,20 +326,12 @@ where
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
{
fn lookup(&self, key: &K) -> Option<(&'tcx V, DepNodeIndex)> {
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
if let Some(Some(value)) = lock.get(*key) {
let hit_result = on_hit(&&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
if let Some(Some(value)) = lock.get(*key) { Some((&value.0, value.1)) } else { None }
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_query_system/src/query/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub trait QueryConfig<Qcx: QueryContext> {

type Key: DepNodeParams<Qcx::DepContext> + Eq + Hash + Clone + Debug;
type Value: Debug;
type Stored: Debug + Clone + std::borrow::Borrow<Self::Value>;
type Stored: Debug + Copy + std::borrow::Borrow<Self::Value>;

type Cache: QueryCache<Key = Self::Key, Stored = Self::Stored, Value = Self::Value>;

Expand Down
60 changes: 30 additions & 30 deletions compiler/rustc_query_system/src/query/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fn mk_cycle<Qcx, V, R, D: DepKind>(
where
Qcx: QueryContext + crate::query::HasDepContext<DepKind = D>,
V: std::fmt::Debug + Value<Qcx::DepContext, Qcx::DepKind>,
R: Clone,
R: Copy,
{
let error = report_cycle(qcx.dep_context().sess(), &cycle_error);
let value = handle_cycle_error(*qcx.dep_context(), &cycle_error, error, handler);
Expand Down Expand Up @@ -339,25 +339,21 @@ where
/// which will be used if the query is not in the cache and we need
/// to compute it.
#[inline]
pub fn try_get_cached<Tcx, C, R, OnHit>(
tcx: Tcx,
cache: &C,
key: &C::Key,
// `on_hit` can be called while holding a lock to the query cache
on_hit: OnHit,
) -> Result<R, ()>
pub fn try_get_cached<Tcx, C>(tcx: Tcx, cache: &C, key: &C::Key) -> Option<C::Stored>
where
C: QueryCache,
Tcx: DepContext,
OnHit: FnOnce(&C::Stored) -> R,
{
cache.lookup(&key, |value, index| {
if std::intrinsics::unlikely(tcx.profiler().enabled()) {
tcx.profiler().query_cache_hit(index.into());
match cache.lookup(&key) {
Some((value, index)) => {
if std::intrinsics::unlikely(tcx.profiler().enabled()) {
tcx.profiler().query_cache_hit(index.into());
}
tcx.dep_graph().read_index(index);
Some(value)
}
tcx.dep_graph().read_index(index);
on_hit(value)
})
None => None,
}
}

fn try_execute_query<Q, Qcx>(
Expand All @@ -379,17 +375,25 @@ where
if Q::FEEDABLE {
// We may have put a value inside the cache from inside the execution.
// Verify that it has the same hash as what we have now, to ensure consistency.
let _ = cache.lookup(&key, |cached_result, _| {
if let Some((cached_result, _)) = cache.lookup(&key) {
let hasher = Q::HASH_RESULT.expect("feedable forbids no_hash");

let old_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| hasher(&mut hcx, cached_result.borrow()));
let new_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| hasher(&mut hcx, &result));
let old_hash = qcx.dep_context().with_stable_hashing_context(|mut hcx| {
hasher(&mut hcx, cached_result.borrow())
});
let new_hash = qcx
.dep_context()
.with_stable_hashing_context(|mut hcx| hasher(&mut hcx, &result));
debug_assert_eq!(
old_hash, new_hash,
old_hash,
new_hash,
"Computed query value for {:?}({:?}) is inconsistent with fed value,\ncomputed={:#?}\nfed={:#?}",
Q::DEP_KIND, key, result, cached_result,
Q::DEP_KIND,
key,
result,
cached_result,
);
});
}
}
let result = job.complete(cache, result, dep_node_index);
(result, Some(dep_node_index))
Expand All @@ -400,9 +404,9 @@ where
}
#[cfg(parallel_compiler)]
TryGetJob::JobCompleted(query_blocked_prof_timer) => {
let (v, index) = cache
.lookup(&key, |value, index| (value.clone(), index))
.unwrap_or_else(|_| panic!("value must be in cache after waiting"));
let Some((v, index)) = cache.lookup(&key) else {
panic!("value must be in cache after waiting")
};

if std::intrinsics::unlikely(qcx.dep_context().profiler().enabled()) {
qcx.dep_context().profiler().query_cache_hit(index.into());
Expand Down Expand Up @@ -771,15 +775,11 @@ where
// We may be concurrently trying both execute and force a query.
// Ensure that only one of them runs the query.
let cache = Q::query_cache(qcx);
let cached = cache.lookup(&key, |_, index| {
if let Some((_, index)) = cache.lookup(&key) {
if std::intrinsics::unlikely(qcx.dep_context().profiler().enabled()) {
qcx.dep_context().profiler().query_cache_hit(index.into());
}
});

match cached {
Ok(()) => return,
Err(()) => {}
return;
}

let state = Q::query_state(qcx);
Expand Down