From 75cfeb1854ab0da78e7049325a0d7e3d6bb3f7ee Mon Sep 17 00:00:00 2001 From: lcnr Date: Fri, 8 Dec 2023 14:28:57 +0100 Subject: [PATCH 1/2] cleanup type variable storage --- compiler/rustc_hir_typeck/src/fallback.rs | 18 ++-- .../src/infer/canonical/query_response.rs | 5 +- compiler/rustc_infer/src/infer/generalize.rs | 2 +- compiler/rustc_infer/src/infer/mod.rs | 20 ++--- compiler/rustc_infer/src/infer/resolve.rs | 2 +- .../rustc_infer/src/infer/type_variable.rs | 87 ++++--------------- compiler/rustc_infer/src/infer/undo_log.rs | 6 +- 7 files changed, 43 insertions(+), 97 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/fallback.rs b/compiler/rustc_hir_typeck/src/fallback.rs index 38b780367e693..47add0435989d 100644 --- a/compiler/rustc_hir_typeck/src/fallback.rs +++ b/compiler/rustc_hir_typeck/src/fallback.rs @@ -57,19 +57,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> { } fn fallback_types(&self) -> bool { - // Check if we have any unsolved variables. If not, no need for fallback. - let unsolved_variables = self.unsolved_variables(); + // Check if we have any unresolved variables. If not, no need for fallback. + let unresolved_variables = self.unresolved_variables(); - if unsolved_variables.is_empty() { + if unresolved_variables.is_empty() { return false; } - let diverging_fallback = self.calculate_diverging_fallback(&unsolved_variables); + let diverging_fallback = self.calculate_diverging_fallback(&unresolved_variables); // We do fallback in two passes, to try to generate // better error messages. // The first time, we do *not* replace opaque types. - for ty in unsolved_variables { + // + // TODO: We return `true` even if no fallback occurs. + for ty in unresolved_variables { debug!("unsolved_variable = {:?}", ty); self.fallback_if_possible(ty, &diverging_fallback); } @@ -230,9 +232,9 @@ impl<'tcx> FnCtxt<'_, 'tcx> { /// any variable that has an edge into `D`. fn calculate_diverging_fallback( &self, - unsolved_variables: &[Ty<'tcx>], + unresolved_variables: &[Ty<'tcx>], ) -> UnordMap, Ty<'tcx>> { - debug!("calculate_diverging_fallback({:?})", unsolved_variables); + debug!("calculate_diverging_fallback({:?})", unresolved_variables); // Construct a coercion graph where an edge `A -> B` indicates // a type variable is that is coerced @@ -240,7 +242,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { // Extract the unsolved type inference variable vids; note that some // unsolved variables are integer/float variables and are excluded. - let unsolved_vids = unsolved_variables.iter().filter_map(|ty| ty.ty_vid()); + let unsolved_vids = unresolved_variables.iter().filter_map(|ty| ty.ty_vid()); // Compute the diverging root vids D -- that is, the root vid of // those type variables that (a) are the target of a coercion from diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 6860d0de179ea..8cca4c6231fc2 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -166,10 +166,7 @@ impl<'tcx> InferCtxt<'tcx> { } fn take_opaque_types_for_query_response(&self) -> Vec<(ty::OpaqueTypeKey<'tcx>, Ty<'tcx>)> { - std::mem::take(&mut self.inner.borrow_mut().opaque_type_storage.opaque_types) - .into_iter() - .map(|(k, v)| (k, v.hidden_type.ty)) - .collect() + self.take_opaque_types().into_iter().map(|(k, v)| (k, v.hidden_type.ty)).collect() } /// Given the (canonicalized) result to a canonical query, diff --git a/compiler/rustc_infer/src/infer/generalize.rs b/compiler/rustc_infer/src/infer/generalize.rs index 4b4017cec57ab..383f3bdbe23bc 100644 --- a/compiler/rustc_infer/src/infer/generalize.rs +++ b/compiler/rustc_infer/src/infer/generalize.rs @@ -295,7 +295,7 @@ where ty::Covariant | ty::Contravariant => (), } - let origin = *inner.type_variables().var_origin(vid); + let origin = inner.type_variables().var_origin(vid); let new_var_id = inner.type_variables().new_var(self.for_universe, origin); let u = Ty::new_var(self.tcx(), new_var_id); diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs index 32c09e491c7ed..1aebf8cde6139 100644 --- a/compiler/rustc_infer/src/infer/mod.rs +++ b/compiler/rustc_infer/src/infer/mod.rs @@ -98,6 +98,8 @@ pub(crate) type UnificationTable<'a, 'tcx, T> = ut::UnificationTable< /// call to `start_snapshot` and `rollback_to`. #[derive(Clone)] pub struct InferCtxtInner<'tcx> { + undo_log: InferCtxtUndoLogs<'tcx>, + /// Cache for projections. /// /// This cache is snapshotted along with the infcx. @@ -162,8 +164,6 @@ pub struct InferCtxtInner<'tcx> { /// that all type inference variables have been bound and so forth. region_obligations: Vec>, - undo_log: InferCtxtUndoLogs<'tcx>, - /// Caches for opaque type inference. opaque_type_storage: OpaqueTypeStorage<'tcx>, } @@ -171,9 +171,10 @@ pub struct InferCtxtInner<'tcx> { impl<'tcx> InferCtxtInner<'tcx> { fn new() -> InferCtxtInner<'tcx> { InferCtxtInner { + undo_log: InferCtxtUndoLogs::default(), + projection_cache: Default::default(), type_variable_storage: type_variable::TypeVariableStorage::new(), - undo_log: InferCtxtUndoLogs::default(), const_unification_storage: ut::UnificationTableStorage::new(), int_unification_storage: ut::UnificationTableStorage::new(), float_unification_storage: ut::UnificationTableStorage::new(), @@ -759,7 +760,7 @@ impl<'tcx> InferCtxt<'tcx> { pub fn type_var_origin(&self, ty: Ty<'tcx>) -> Option { match *ty.kind() { ty::Infer(ty::TyVar(vid)) => { - Some(*self.inner.borrow_mut().type_variables().var_origin(vid)) + Some(self.inner.borrow_mut().type_variables().var_origin(vid)) } _ => None, } @@ -769,11 +770,11 @@ impl<'tcx> InferCtxt<'tcx> { freshen::TypeFreshener::new(self) } - pub fn unsolved_variables(&self) -> Vec> { + pub fn unresolved_variables(&self) -> Vec> { let mut inner = self.inner.borrow_mut(); let mut vars: Vec> = inner .type_variables() - .unsolved_variables() + .unresolved_variables() .into_iter() .map(|t| Ty::new_var(self.tcx, t)) .collect(); @@ -1282,12 +1283,7 @@ impl<'tcx> InferCtxt<'tcx> { pub fn region_var_origin(&self, vid: ty::RegionVid) -> RegionVariableOrigin { let mut inner = self.inner.borrow_mut(); let inner = &mut *inner; - inner - .region_constraint_storage - .as_mut() - .expect("regions already resolved") - .with_log(&mut inner.undo_log) - .var_origin(vid) + inner.unwrap_region_constraints().var_origin(vid) } /// Clone the list of variable regions. This is used only during NLL processing diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs index ece30bbba12ac..f317ccee6918a 100644 --- a/compiler/rustc_infer/src/infer/resolve.rs +++ b/compiler/rustc_infer/src/infer/resolve.rs @@ -134,7 +134,7 @@ impl<'a, 'tcx> TypeVisitor> for UnresolvedTypeOrConstFinder<'a, 'tc if let TypeVariableOrigin { kind: TypeVariableOriginKind::TypeParameterDefinition(_, _), span, - } = *ty_vars.var_origin(ty_vid) + } = ty_vars.var_origin(ty_vid) { Some(span) } else { diff --git a/compiler/rustc_infer/src/infer/type_variable.rs b/compiler/rustc_infer/src/infer/type_variable.rs index bd6f905c8241d..58b8110157bfb 100644 --- a/compiler/rustc_infer/src/infer/type_variable.rs +++ b/compiler/rustc_infer/src/infer/type_variable.rs @@ -1,4 +1,5 @@ use rustc_hir::def_id::DefId; +use rustc_index::IndexVec; use rustc_middle::ty::{self, Ty, TyVid}; use rustc_span::symbol::Symbol; use rustc_span::Span; @@ -11,14 +12,13 @@ use std::cmp; use std::marker::PhantomData; use std::ops::Range; -use rustc_data_structures::undo_log::{Rollback, UndoLogs}; +use rustc_data_structures::undo_log::Rollback; /// Represents a single undo-able action that affects a type inference variable. #[derive(Clone)] pub(crate) enum UndoLog<'tcx> { EqRelation(sv::UndoLog>>), SubRelation(sv::UndoLog>), - Values(sv::UndoLog), } /// Convert from a specific kind of undo to the more general UndoLog @@ -35,34 +35,19 @@ impl<'tcx> From>> for UndoLog<'tcx> { } } -/// Convert from a specific kind of undo to the more general UndoLog -impl<'tcx> From> for UndoLog<'tcx> { - fn from(l: sv::UndoLog) -> Self { - UndoLog::Values(l) - } -} - -/// Convert from a specific kind of undo to the more general UndoLog -impl<'tcx> From for UndoLog<'tcx> { - fn from(l: Instantiate) -> Self { - UndoLog::Values(sv::UndoLog::Other(l)) - } -} - impl<'tcx> Rollback> for TypeVariableStorage<'tcx> { fn reverse(&mut self, undo: UndoLog<'tcx>) { match undo { UndoLog::EqRelation(undo) => self.eq_relations.reverse(undo), UndoLog::SubRelation(undo) => self.sub_relations.reverse(undo), - UndoLog::Values(undo) => self.values.reverse(undo), } } } #[derive(Clone)] pub struct TypeVariableStorage<'tcx> { - values: sv::SnapshotVecStorage, - + /// The origins of each type variable. + values: IndexVec, /// Two variables are unified in `eq_relations` when we have a /// constraint `?X == ?Y`. This table also stores, for each key, /// the known value. @@ -168,15 +153,10 @@ impl<'tcx> TypeVariableValue<'tcx> { } } -#[derive(Clone)] -pub(crate) struct Instantiate; - -pub(crate) struct Delegate; - impl<'tcx> TypeVariableStorage<'tcx> { pub fn new() -> TypeVariableStorage<'tcx> { TypeVariableStorage { - values: sv::SnapshotVecStorage::new(), + values: Default::default(), eq_relations: ut::UnificationTableStorage::new(), sub_relations: ut::UnificationTableStorage::new(), } @@ -194,6 +174,11 @@ impl<'tcx> TypeVariableStorage<'tcx> { pub(crate) fn eq_relations_ref(&self) -> &ut::UnificationTableStorage> { &self.eq_relations } + + pub(super) fn finalize_rollback(&mut self) { + debug_assert!(self.values.len() >= self.eq_relations.len()); + self.values.truncate(self.eq_relations.len()); + } } impl<'tcx> TypeVariableTable<'_, 'tcx> { @@ -201,8 +186,8 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { /// /// Note that this function does not return care whether /// `vid` has been unified with something else or not. - pub fn var_origin(&self, vid: ty::TyVid) -> &TypeVariableOrigin { - &self.storage.values.get(vid.as_usize()).origin + pub fn var_origin(&self, vid: ty::TyVid) -> TypeVariableOrigin { + self.storage.values[vid].origin } /// Records that `a == b`, depending on `dir`. @@ -237,11 +222,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { self.eq_relations().probe_value(vid) ); self.eq_relations().union_value(vid, TypeVariableValue::Known { value: ty }); - - // Hack: we only need this so that `types_escaping_snapshot` - // can see what has been unified; see the Delegate impl for - // more details. - self.undo_log.push(Instantiate); } /// Creates a new type variable. @@ -262,14 +242,14 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { let eq_key = self.eq_relations().new_key(TypeVariableValue::Unknown { universe }); let sub_key = self.sub_relations().new_key(()); - assert_eq!(eq_key.vid, sub_key); + debug_assert_eq!(eq_key.vid, sub_key); - let index = self.values().push(TypeVariableData { origin }); - assert_eq!(eq_key.vid.as_u32(), index as u32); + let index = self.storage.values.push(TypeVariableData { origin }); + debug_assert_eq!(eq_key.vid, index); debug!("new_var(index={:?}, universe={:?}, origin={:?})", eq_key.vid, universe, origin); - eq_key.vid + index } /// Returns the number of type variables created thus far. @@ -329,13 +309,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { } } - #[inline] - fn values( - &mut self, - ) -> sv::SnapshotVec, &mut InferCtxtUndoLogs<'tcx>> { - self.storage.values.with_log(self.undo_log) - } - #[inline] fn eq_relations(&mut self) -> super::UnificationTable<'_, 'tcx, TyVidEqKey<'tcx>> { self.storage.eq_relations.with_log(self.undo_log) @@ -354,16 +327,14 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { let range = TyVid::from_usize(value_count)..TyVid::from_usize(self.num_vars()); ( range.start..range.end, - (range.start.as_usize()..range.end.as_usize()) - .map(|index| self.storage.values.get(index).origin) - .collect(), + (range.start..range.end).map(|index| self.var_origin(index)).collect(), ) } /// Returns indices of all variables that are not yet /// instantiated. - pub fn unsolved_variables(&mut self) -> Vec { - (0..self.storage.values.len()) + pub fn unresolved_variables(&mut self) -> Vec { + (0..self.num_vars()) .filter_map(|i| { let vid = ty::TyVid::from_usize(i); match self.probe(vid) { @@ -375,26 +346,6 @@ impl<'tcx> TypeVariableTable<'_, 'tcx> { } } -impl sv::SnapshotVecDelegate for Delegate { - type Value = TypeVariableData; - type Undo = Instantiate; - - fn reverse(_values: &mut Vec, _action: Instantiate) { - // We don't actually have to *do* anything to reverse an - // instantiation; the value for a variable is stored in the - // `eq_relations` and hence its rollback code will handle - // it. In fact, we could *almost* just remove the - // `SnapshotVec` entirely, except that we would have to - // reproduce *some* of its logic, since we want to know which - // type variables have been instantiated since the snapshot - // was started, so we can implement `types_escaping_snapshot`. - // - // (If we extended the `UnificationTable` to let us see which - // values have been unified and so forth, that might also - // suffice.) - } -} - /////////////////////////////////////////////////////////////////////////// /// These structs (a newtyped TyVid) are used as the unification key diff --git a/compiler/rustc_infer/src/infer/undo_log.rs b/compiler/rustc_infer/src/infer/undo_log.rs index 5655730518e5d..be02452d89fbf 100644 --- a/compiler/rustc_infer/src/infer/undo_log.rs +++ b/compiler/rustc_infer/src/infer/undo_log.rs @@ -32,7 +32,7 @@ pub(crate) enum UndoLog<'tcx> { } macro_rules! impl_from { - ($($ctor: ident ($ty: ty),)*) => { + ($($ctor:ident ($ty:ty),)*) => { $( impl<'tcx> From<$ty> for UndoLog<'tcx> { fn from(x: $ty) -> Self { @@ -50,8 +50,6 @@ impl_from! { TypeVariables(sv::UndoLog>>), TypeVariables(sv::UndoLog>), - TypeVariables(sv::UndoLog), - TypeVariables(type_variable::Instantiate), IntUnificationTable(sv::UndoLog>), @@ -140,6 +138,8 @@ impl<'tcx> InferCtxtInner<'tcx> { self.reverse(undo); } + self.type_variable_storage.finalize_rollback(); + if self.undo_log.num_open_snapshots == 1 { // After the root snapshot the undo log should be empty. assert!(snapshot.undo_len == 0); From 929782658e3dd7dc6b9aa3ce0868f64e6e5e3e23 Mon Sep 17 00:00:00 2001 From: lcnr Date: Fri, 8 Dec 2023 14:50:07 +0100 Subject: [PATCH 2/2] only return true in `fallback_types' if fallback has occurred --- compiler/rustc_hir_typeck/src/fallback.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/fallback.rs b/compiler/rustc_hir_typeck/src/fallback.rs index 47add0435989d..023bd70be174e 100644 --- a/compiler/rustc_hir_typeck/src/fallback.rs +++ b/compiler/rustc_hir_typeck/src/fallback.rs @@ -24,9 +24,9 @@ impl<'tcx> FnCtxt<'_, 'tcx> { self.fulfillment_cx.borrow_mut().pending_obligations() ); - let fallback_occured = self.fallback_types() | self.fallback_effects(); + let fallback_occurred = self.fallback_types() | self.fallback_effects(); - if !fallback_occured { + if !fallback_occurred { return; } @@ -69,14 +69,13 @@ impl<'tcx> FnCtxt<'_, 'tcx> { // We do fallback in two passes, to try to generate // better error messages. // The first time, we do *not* replace opaque types. - // - // TODO: We return `true` even if no fallback occurs. + let mut fallback_occurred = false; for ty in unresolved_variables { debug!("unsolved_variable = {:?}", ty); - self.fallback_if_possible(ty, &diverging_fallback); + fallback_occurred |= self.fallback_if_possible(ty, &diverging_fallback); } - true + fallback_occurred } fn fallback_effects(&self) -> bool { @@ -86,9 +85,8 @@ impl<'tcx> FnCtxt<'_, 'tcx> { return false; } - // not setting `fallback_has_occured` here because that field is only used for type fallback - // diagnostics. - + // not setting the `fallback_has_occured` field here because + // that field is only used for type fallback diagnostics. for effect in unsolved_effects { let expected = self.tcx.consts.true_; let cause = self.misc(rustc_span::DUMMY_SP); @@ -124,7 +122,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { &self, ty: Ty<'tcx>, diverging_fallback: &UnordMap, Ty<'tcx>>, - ) { + ) -> bool { // Careful: we do NOT shallow-resolve `ty`. We know that `ty` // is an unsolved variable, and we determine its fallback // based solely on how it was created, not what other type @@ -149,7 +147,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { ty::Infer(ty::FloatVar(_)) => self.tcx.types.f64, _ => match diverging_fallback.get(&ty) { Some(&fallback_ty) => fallback_ty, - None => return, + None => return false, }, }; debug!("fallback_if_possible(ty={:?}): defaulting to `{:?}`", ty, fallback); @@ -161,6 +159,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> { .unwrap_or(rustc_span::DUMMY_SP); self.demand_eqtype(span, ty, fallback); self.fallback_has_occurred.set(true); + true } /// The "diverging fallback" system is rather complicated. This is