diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 6f42b69633c84..c8417ea6c3a65 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -1644,6 +1644,14 @@ impl<'tcx> PlaceRef<'tcx> { } } + /// Returns `true` if this `Place` contains a `Deref` projection. + /// + /// If `Place::is_indirect` returns false, the caller knows that the `Place` refers to the + /// same region of memory as its base. + pub fn is_indirect(&self) -> bool { + self.projection.iter().any(|elem| elem.is_indirect()) + } + /// If MirPhase >= Derefered and if projection contains Deref, /// It's guaranteed to be in the first place pub fn has_deref(&self) -> bool { diff --git a/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs b/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs index 0f8e86d1d6679..6f4e7fd4682c1 100644 --- a/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs +++ b/compiler/rustc_mir_dataflow/src/impls/borrowed_locals.rs @@ -121,7 +121,9 @@ where // for now. See discussion on [#61069]. // // [#61069]: https://github.com/rust-lang/rust/pull/61069 - self.trans.gen(dropped_place.local); + if !dropped_place.is_indirect() { + self.trans.gen(dropped_place.local); + } } TerminatorKind::Abort diff --git a/compiler/rustc_mir_dataflow/src/lib.rs b/compiler/rustc_mir_dataflow/src/lib.rs index 7f40cfca32fff..3e382f500afbe 100644 --- a/compiler/rustc_mir_dataflow/src/lib.rs +++ b/compiler/rustc_mir_dataflow/src/lib.rs @@ -1,6 +1,7 @@ #![feature(associated_type_defaults)] #![feature(box_patterns)] #![feature(exact_size_is_empty)] +#![feature(let_chains)] #![feature(min_specialization)] #![feature(once_cell)] #![feature(stmt_expr_attributes)] diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index a6ef2a742c873..401db890a9810 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -24,7 +24,7 @@ //! - The bottom state denotes uninitialized memory. Because we are only doing a sound approximation //! of the actual execution, we can also use this state for places where access would be UB. //! -//! - The assignment logic in `State::assign_place_idx` assumes that the places are non-overlapping, +//! - The assignment logic in `State::insert_place_idx` assumes that the places are non-overlapping, //! or identical. Note that this refers to place expressions, not memory locations. //! //! - Currently, places that have their reference taken cannot be tracked. Although this would be @@ -35,6 +35,7 @@ use std::fmt::{Debug, Formatter}; use rustc_data_structures::fx::FxHashMap; +use rustc_index::bit_set::BitSet; use rustc_index::vec::IndexVec; use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; @@ -64,10 +65,8 @@ pub trait ValueAnalysis<'tcx> { StatementKind::Assign(box (place, rvalue)) => { self.handle_assign(*place, rvalue, state); } - StatementKind::SetDiscriminant { .. } => { - // Could treat this as writing a constant to a pseudo-place. - // But discriminants are currently not tracked, so we do nothing. - // Related: https://github.com/rust-lang/unsafe-code-guidelines/issues/84 + StatementKind::SetDiscriminant { box ref place, .. } => { + state.flood_discr(place.as_ref(), self.map()); } StatementKind::Intrinsic(box intrinsic) => { self.handle_intrinsic(intrinsic, state); @@ -446,26 +445,51 @@ impl State { } pub fn flood_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) { - if let Some(root) = map.find(place) { - self.flood_idx_with(root, map, value); - } + let StateData::Reachable(values) = &mut self.0 else { return }; + map.for_each_aliasing_place(place, None, &mut |place| { + if let Some(vi) = map.places[place].value_index { + values[vi] = value.clone(); + } + }); } pub fn flood(&mut self, place: PlaceRef<'_>, map: &Map) { self.flood_with(place, map, V::top()) } - pub fn flood_idx_with(&mut self, place: PlaceIndex, map: &Map, value: V) { + pub fn flood_discr_with(&mut self, place: PlaceRef<'_>, map: &Map, value: V) { let StateData::Reachable(values) = &mut self.0 else { return }; - map.preorder_invoke(place, &mut |place| { + map.for_each_aliasing_place(place, Some(TrackElem::Discriminant), &mut |place| { if let Some(vi) = map.places[place].value_index { values[vi] = value.clone(); } }); } - pub fn flood_idx(&mut self, place: PlaceIndex, map: &Map) { - self.flood_idx_with(place, map, V::top()) + pub fn flood_discr(&mut self, place: PlaceRef<'_>, map: &Map) { + self.flood_discr_with(place, map, V::top()) + } + + /// Low-level method that assigns to a place. + /// This does nothing if the place is not tracked. + /// + /// The target place must have been flooded before calling this method. + pub fn insert_idx(&mut self, target: PlaceIndex, result: ValueOrPlace, map: &Map) { + match result { + ValueOrPlace::Value(value) => self.insert_value_idx(target, value, map), + ValueOrPlace::Place(source) => self.insert_place_idx(target, source, map), + } + } + + /// Low-level method that assigns a value to a place. + /// This does nothing if the place is not tracked. + /// + /// The target place must have been flooded before calling this method. + pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) { + let StateData::Reachable(values) = &mut self.0 else { return }; + if let Some(value_index) = map.places[target].value_index { + values[value_index] = value; + } } /// Copies `source` to `target`, including all tracked places beneath. @@ -473,50 +497,41 @@ impl State { /// If `target` contains a place that is not contained in `source`, it will be overwritten with /// Top. Also, because this will copy all entries one after another, it may only be used for /// places that are non-overlapping or identical. - pub fn assign_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) { + /// + /// The target place must have been flooded before calling this method. + fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) { let StateData::Reachable(values) = &mut self.0 else { return }; - // If both places are tracked, we copy the value to the target. If the target is tracked, - // but the source is not, we have to invalidate the value in target. If the target is not - // tracked, then we don't have to do anything. + // If both places are tracked, we copy the value to the target. + // If the target is tracked, but the source is not, we do nothing, as invalidation has + // already been performed. if let Some(target_value) = map.places[target].value_index { if let Some(source_value) = map.places[source].value_index { values[target_value] = values[source_value].clone(); - } else { - values[target_value] = V::top(); } } for target_child in map.children(target) { // Try to find corresponding child and recurse. Reasoning is similar as above. let projection = map.places[target_child].proj_elem.unwrap(); if let Some(source_child) = map.projections.get(&(source, projection)) { - self.assign_place_idx(target_child, *source_child, map); - } else { - self.flood_idx(target_child, map); + self.insert_place_idx(target_child, *source_child, map); } } } + /// Helper method to interpret `target = result`. pub fn assign(&mut self, target: PlaceRef<'_>, result: ValueOrPlace, map: &Map) { + self.flood(target, map); if let Some(target) = map.find(target) { - self.assign_idx(target, result, map); - } else { - // We don't track this place nor any projections, assignment can be ignored. + self.insert_idx(target, result, map); } } - pub fn assign_idx(&mut self, target: PlaceIndex, result: ValueOrPlace, map: &Map) { - match result { - ValueOrPlace::Value(value) => { - // First flood the target place in case we also track any projections (although - // this scenario is currently not well-supported by the API). - self.flood_idx(target, map); - let StateData::Reachable(values) = &mut self.0 else { return }; - if let Some(value_index) = map.places[target].value_index { - values[value_index] = value; - } - } - ValueOrPlace::Place(source) => self.assign_place_idx(target, source, map), + /// Helper method for assignments to a discriminant. + pub fn assign_discr(&mut self, target: PlaceRef<'_>, result: ValueOrPlace, map: &Map) { + self.flood_discr(target, map); + if let Some(target) = map.find_discr(target) { + self.insert_idx(target, result, map); } } @@ -525,6 +540,14 @@ impl State { map.find(place).map(|place| self.get_idx(place, map)).unwrap_or(V::top()) } + /// Retrieve the value stored for a place, or ⊤ if it is not tracked. + pub fn get_discr(&self, place: PlaceRef<'_>, map: &Map) -> V { + match map.find_discr(place) { + Some(place) => self.get_idx(place, map), + None => V::top(), + } + } + /// Retrieve the value stored for a place index, or ⊤ if it is not tracked. pub fn get_idx(&self, place: PlaceIndex, map: &Map) -> V { match &self.0 { @@ -581,15 +604,15 @@ impl Map { /// This is currently the only way to create a [`Map`]. The way in which the tracked places are /// chosen is an implementation detail and may not be relied upon (other than that their type /// passes the filter). - #[instrument(skip_all, level = "debug")] pub fn from_filter<'tcx>( tcx: TyCtxt<'tcx>, body: &Body<'tcx>, filter: impl FnMut(Ty<'tcx>) -> bool, + place_limit: Option, ) -> Self { let mut map = Self::new(); let exclude = excluded_locals(body); - map.register_with_filter(tcx, body, filter, &exclude); + map.register_with_filter(tcx, body, filter, exclude, place_limit); debug!("registered {} places ({} nodes in total)", map.value_count, map.places.len()); map } @@ -600,20 +623,28 @@ impl Map { tcx: TyCtxt<'tcx>, body: &Body<'tcx>, mut filter: impl FnMut(Ty<'tcx>) -> bool, - exclude: &IndexVec, + exclude: BitSet, + place_limit: Option, ) { // We use this vector as stack, pushing and popping projections. let mut projection = Vec::new(); for (local, decl) in body.local_decls.iter_enumerated() { - if !exclude[local] { - self.register_with_filter_rec(tcx, local, &mut projection, decl.ty, &mut filter); + if !exclude.contains(local) { + self.register_with_filter_rec( + tcx, + local, + &mut projection, + decl.ty, + &mut filter, + place_limit, + ); } } } /// Potentially register the (local, projection) place and its fields, recursively. /// - /// Invariant: The projection must only contain fields. + /// Invariant: The projection must only contain trackable elements. fn register_with_filter_rec<'tcx>( &mut self, tcx: TyCtxt<'tcx>, @@ -621,27 +652,56 @@ impl Map { projection: &mut Vec>, ty: Ty<'tcx>, filter: &mut impl FnMut(Ty<'tcx>) -> bool, + place_limit: Option, ) { - // Note: The framework supports only scalars for now. - if filter(ty) && ty.is_scalar() { - // We know that the projection only contains trackable elements. - let place = self.make_place(local, projection).unwrap(); + if let Some(place_limit) = place_limit && self.value_count >= place_limit { + return + } + + // We know that the projection only contains trackable elements. + let place = self.make_place(local, projection).unwrap(); - // Allocate a value slot if it doesn't have one. - if self.places[place].value_index.is_none() { - self.places[place].value_index = Some(self.value_count.into()); - self.value_count += 1; + // Allocate a value slot if it doesn't have one, and the user requested one. + if self.places[place].value_index.is_none() && filter(ty) { + self.places[place].value_index = Some(self.value_count.into()); + self.value_count += 1; + } + + if ty.is_enum() { + let discr_ty = ty.discriminant_ty(tcx); + if filter(discr_ty) { + let discr = *self + .projections + .entry((place, TrackElem::Discriminant)) + .or_insert_with(|| { + // Prepend new child to the linked list. + let next = self.places.push(PlaceInfo::new(Some(TrackElem::Discriminant))); + self.places[next].next_sibling = self.places[place].first_child; + self.places[place].first_child = Some(next); + next + }); + + // Allocate a value slot if it doesn't have one. + if self.places[discr].value_index.is_none() { + self.places[discr].value_index = Some(self.value_count.into()); + self.value_count += 1; + } } } // Recurse with all fields of this place. iter_fields(ty, tcx, |variant, field, ty| { - if variant.is_some() { - // Downcasts are currently not supported. + if let Some(variant) = variant { + projection.push(PlaceElem::Downcast(None, variant)); + let _ = self.make_place(local, projection); + projection.push(PlaceElem::Field(field, ty)); + self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit); + projection.pop(); + projection.pop(); return; } projection.push(PlaceElem::Field(field, ty)); - self.register_with_filter_rec(tcx, local, projection, ty, filter); + self.register_with_filter_rec(tcx, local, projection, ty, filter, place_limit); projection.pop(); }); } @@ -684,23 +744,105 @@ impl Map { } /// Locates the given place, if it exists in the tree. - pub fn find(&self, place: PlaceRef<'_>) -> Option { + pub fn find_extra( + &self, + place: PlaceRef<'_>, + extra: impl IntoIterator, + ) -> Option { let mut index = *self.locals.get(place.local)?.as_ref()?; for &elem in place.projection { index = self.apply(index, elem.try_into().ok()?)?; } + for elem in extra { + index = self.apply(index, elem)?; + } Some(index) } + /// Locates the given place, if it exists in the tree. + pub fn find(&self, place: PlaceRef<'_>) -> Option { + self.find_extra(place, []) + } + + /// Locates the given place and applies `Discriminant`, if it exists in the tree. + pub fn find_discr(&self, place: PlaceRef<'_>) -> Option { + self.find_extra(place, [TrackElem::Discriminant]) + } + /// Iterate over all direct children. pub fn children(&self, parent: PlaceIndex) -> impl Iterator + '_ { Children::new(self, parent) } + /// Invoke a function on the given place and all places that may alias it. + /// + /// In particular, when the given place has a variant downcast, we invoke the function on all + /// the other variants. + /// + /// `tail_elem` allows to support discriminants that are not a place in MIR, but that we track + /// as such. + pub fn for_each_aliasing_place( + &self, + place: PlaceRef<'_>, + tail_elem: Option, + f: &mut impl FnMut(PlaceIndex), + ) { + if place.is_indirect() { + // We do not track indirect places. + return; + } + let Some(&Some(mut index)) = self.locals.get(place.local) else { + // The local is not tracked at all, so it does not alias anything. + return; + }; + let elems = place + .projection + .iter() + .map(|&elem| elem.try_into()) + .chain(tail_elem.map(Ok).into_iter()); + for elem in elems { + // A field aliases the parent place. + f(index); + + let Ok(elem) = elem else { return }; + let sub = self.apply(index, elem); + if let TrackElem::Variant(..) | TrackElem::Discriminant = elem { + // Enum variant fields and enum discriminants alias each another. + self.for_each_variant_sibling(index, sub, f); + } + if let Some(sub) = sub { + index = sub + } else { + return; + } + } + self.preorder_invoke(index, f); + } + + /// Invoke the given function on all the descendants of the given place, except one branch. + fn for_each_variant_sibling( + &self, + parent: PlaceIndex, + preserved_child: Option, + f: &mut impl FnMut(PlaceIndex), + ) { + for sibling in self.children(parent) { + let elem = self.places[sibling].proj_elem; + // Only invalidate variants and discriminant. Fields (for generators) are not + // invalidated by assignment to a variant. + if let Some(TrackElem::Variant(..) | TrackElem::Discriminant) = elem + // Only invalidate the other variants, the current one is fine. + && Some(sibling) != preserved_child + { + self.preorder_invoke(sibling, f); + } + } + } + /// Invoke a function on the given place and all descendants. - pub fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) { + fn preorder_invoke(&self, root: PlaceIndex, f: &mut impl FnMut(PlaceIndex)) { f(root); for child in self.children(root) { self.preorder_invoke(child, f); @@ -759,6 +901,7 @@ impl<'a> Iterator for Children<'a> { } /// Used as the result of an operand or r-value. +#[derive(Debug)] pub enum ValueOrPlace { Value(V), Place(PlaceIndex), @@ -776,6 +919,8 @@ impl ValueOrPlace { #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum TrackElem { Field(Field), + Variant(VariantIdx), + Discriminant, } impl TryFrom> for TrackElem { @@ -784,6 +929,7 @@ impl TryFrom> for TrackElem { fn try_from(value: ProjectionElem) -> Result { match value { ProjectionElem::Field(field, _) => Ok(TrackElem::Field(field)), + ProjectionElem::Downcast(_, idx) => Ok(TrackElem::Variant(idx)), _ => Err(()), } } @@ -824,26 +970,27 @@ pub fn iter_fields<'tcx>( } /// Returns all locals with projections that have their reference or address taken. -pub fn excluded_locals(body: &Body<'_>) -> IndexVec { +pub fn excluded_locals(body: &Body<'_>) -> BitSet { struct Collector { - result: IndexVec, + result: BitSet, } impl<'tcx> Visitor<'tcx> for Collector { fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { - if context.is_borrow() + if (context.is_borrow() || context.is_address_of() || context.is_drop() - || context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput) + || context == PlaceContext::MutatingUse(MutatingUseContext::AsmOutput)) + && !place.is_indirect() { // A pointer to a place could be used to access other places with the same local, // hence we have to exclude the local completely. - self.result[place.local] = true; + self.result.insert(place.local); } } } - let mut collector = Collector { result: IndexVec::from_elem(false, &body.local_decls) }; + let mut collector = Collector { result: BitSet::new_empty(body.local_decls.len()) }; collector.visit_body(body); collector.result } @@ -899,6 +1046,12 @@ fn debug_with_context_rec( for child in map.children(place) { let info_elem = map.places[child].proj_elem.unwrap(); let child_place_str = match info_elem { + TrackElem::Discriminant => { + format!("discriminant({})", place_str) + } + TrackElem::Variant(idx) => { + format!("({} as {:?})", place_str, idx) + } TrackElem::Field(field) => { if place_str.starts_with('*') { format!("({}).{}", place_str, field.index()) diff --git a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs index 949a59a97bfb6..f3ca2337e59ca 100644 --- a/compiler/rustc_mir_transform/src/dataflow_const_prop.rs +++ b/compiler/rustc_mir_transform/src/dataflow_const_prop.rs @@ -13,6 +13,7 @@ use rustc_mir_dataflow::value_analysis::{Map, State, TrackElem, ValueAnalysis, V use rustc_mir_dataflow::{lattice::FlatSet, Analysis, ResultsVisitor, SwitchIntEdgeEffects}; use rustc_span::DUMMY_SP; use rustc_target::abi::Align; +use rustc_target::abi::VariantIdx; use crate::MirPass; @@ -30,14 +31,12 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp { #[instrument(skip_all level = "debug")] fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!(def_id = ?body.source.def_id()); if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT { debug!("aborted dataflow const prop due too many basic blocks"); return; } - // Decide which places to track during the analysis. - let map = Map::from_filter(tcx, body, Ty::is_scalar); - // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators. // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function // applications, where `h` is the height of the lattice. Because the height of our lattice @@ -46,10 +45,10 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp { // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of // map nodes is strongly correlated to the number of tracked places, this becomes more or // less `O(n)` if we place a constant limit on the number of tracked places. - if tcx.sess.mir_opt_level() < 4 && map.tracked_places() > PLACE_LIMIT { - debug!("aborted dataflow const prop due to too many tracked places"); - return; - } + let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None }; + + // Decide which places to track during the analysis. + let map = Map::from_filter(tcx, body, Ty::is_scalar, place_limit); // Perform the actual dataflow analysis. let analysis = ConstAnalysis::new(tcx, body, map); @@ -63,14 +62,31 @@ impl<'tcx> MirPass<'tcx> for DataflowConstProp { } } -struct ConstAnalysis<'tcx> { +struct ConstAnalysis<'a, 'tcx> { map: Map, tcx: TyCtxt<'tcx>, + local_decls: &'a LocalDecls<'tcx>, ecx: InterpCx<'tcx, 'tcx, DummyMachine>, param_env: ty::ParamEnv<'tcx>, } -impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { +impl<'tcx> ConstAnalysis<'_, 'tcx> { + fn eval_discriminant( + &self, + enum_ty: Ty<'tcx>, + variant_index: VariantIdx, + ) -> Option> { + if !enum_ty.is_enum() { + return None; + } + let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?; + let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?; + let discr_value = Scalar::try_from_uint(discr.val, discr_layout.size)?; + Some(ScalarTy(discr_value, discr.ty)) + } +} + +impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> { type Value = FlatSet>; const NAME: &'static str = "ConstAnalysis"; @@ -79,6 +95,25 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { &self.map } + fn handle_statement(&self, statement: &Statement<'tcx>, state: &mut State) { + match statement.kind { + StatementKind::SetDiscriminant { box ref place, variant_index } => { + state.flood_discr(place.as_ref(), &self.map); + if self.map.find_discr(place.as_ref()).is_some() { + let enum_ty = place.ty(self.local_decls, self.tcx).ty; + if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) { + state.assign_discr( + place.as_ref(), + ValueOrPlace::Value(FlatSet::Elem(discr)), + &self.map, + ); + } + } + } + _ => self.super_statement(statement, state), + } + } + fn handle_assign( &self, target: Place<'tcx>, @@ -87,36 +122,47 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { ) { match rvalue { Rvalue::Aggregate(kind, operands) => { - let target = self.map().find(target.as_ref()); - if let Some(target) = target { - state.flood_idx_with(target, self.map(), FlatSet::Bottom); - let field_based = match **kind { - AggregateKind::Tuple | AggregateKind::Closure(..) => true, - AggregateKind::Adt(def_id, ..) => { - matches!(self.tcx.def_kind(def_id), DefKind::Struct) + state.flood_with(target.as_ref(), self.map(), FlatSet::Bottom); + if let Some(target_idx) = self.map().find(target.as_ref()) { + let (variant_target, variant_index) = match **kind { + AggregateKind::Tuple | AggregateKind::Closure(..) => { + (Some(target_idx), None) } - _ => false, + AggregateKind::Adt(def_id, variant_index, ..) => { + match self.tcx.def_kind(def_id) { + DefKind::Struct => (Some(target_idx), None), + DefKind::Enum => (Some(target_idx), Some(variant_index)), + _ => (None, None), + } + } + _ => (None, None), }; - if field_based { + if let Some(target) = variant_target { for (field_index, operand) in operands.iter().enumerate() { if let Some(field) = self .map() .apply(target, TrackElem::Field(Field::from_usize(field_index))) { let result = self.handle_operand(operand, state); - state.assign_idx(field, result, self.map()); + state.insert_idx(field, result, self.map()); } } } + if let Some(variant_index) = variant_index + && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant) + { + let enum_ty = target.ty(self.local_decls, self.tcx).ty; + if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) { + state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map); + } + } } } Rvalue::CheckedBinaryOp(op, box (left, right)) => { + // Flood everything now, so we can use `insert_value_idx` directly later. + state.flood(target.as_ref(), self.map()); + let target = self.map().find(target.as_ref()); - if let Some(target) = target { - // We should not track any projections other than - // what is overwritten below, but just in case... - state.flood_idx(target, self.map()); - } let value_target = target .and_then(|target| self.map().apply(target, TrackElem::Field(0_u32.into()))); @@ -127,7 +173,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { let (val, overflow) = self.binary_op(state, *op, left, right); if let Some(value_target) = value_target { - state.assign_idx(value_target, ValueOrPlace::Value(val), self.map()); + // We have flooded `target` earlier. + state.insert_value_idx(value_target, val, self.map()); } if let Some(overflow_target) = overflow_target { let overflow = match overflow { @@ -142,11 +189,8 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { } FlatSet::Bottom => FlatSet::Bottom, }; - state.assign_idx( - overflow_target, - ValueOrPlace::Value(overflow), - self.map(), - ); + // We have flooded `target` earlier. + state.insert_value_idx(overflow_target, overflow, self.map()); } } } @@ -195,6 +239,9 @@ impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'tcx> { FlatSet::Bottom => ValueOrPlace::Value(FlatSet::Bottom), FlatSet::Top => ValueOrPlace::Value(FlatSet::Top), }, + Rvalue::Discriminant(place) => { + ValueOrPlace::Value(state.get_discr(place.as_ref(), self.map())) + } _ => self.super_rvalue(rvalue, state), } } @@ -268,12 +315,13 @@ impl<'tcx> std::fmt::Debug for ScalarTy<'tcx> { } } -impl<'tcx> ConstAnalysis<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, map: Map) -> Self { +impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { + pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self { let param_env = tcx.param_env(body.source.def_id()); Self { map, tcx, + local_decls: &body.local_decls, ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), param_env: param_env, } @@ -466,6 +514,21 @@ impl<'tcx, 'map, 'a> Visitor<'tcx> for OperandCollector<'tcx, 'map, 'a> { _ => (), } } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + match rvalue { + Rvalue::Discriminant(place) => { + match self.state.get_discr(place.as_ref(), self.visitor.map) { + FlatSet::Top => (), + FlatSet::Elem(value) => { + self.visitor.before_effect.insert((location, *place), value); + } + FlatSet::Bottom => (), + } + } + _ => self.super_rvalue(rvalue, location), + } + } } struct DummyMachine; diff --git a/compiler/rustc_mir_transform/src/sroa.rs b/compiler/rustc_mir_transform/src/sroa.rs index 30d8511153a2b..8a37423b2a052 100644 --- a/compiler/rustc_mir_transform/src/sroa.rs +++ b/compiler/rustc_mir_transform/src/sroa.rs @@ -1,5 +1,5 @@ use crate::MirPass; -use rustc_index::bit_set::BitSet; +use rustc_index::bit_set::{BitSet, GrowableBitSet}; use rustc_index::vec::IndexVec; use rustc_middle::mir::patch::MirPatch; use rustc_middle::mir::visit::*; @@ -26,10 +26,12 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { debug!(?replacements); let all_dead_locals = replace_flattened_locals(tcx, body, replacements); if !all_dead_locals.is_empty() { - for local in excluded.indices() { - excluded[local] |= all_dead_locals.contains(local); - } - excluded.raw.resize(body.local_decls.len(), false); + excluded.union(&all_dead_locals); + excluded = { + let mut growable = GrowableBitSet::from(excluded); + growable.ensure(body.local_decls.len()); + growable.into() + }; } else { break; } @@ -44,11 +46,11 @@ impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates { /// - the locals is a union or an enum; /// - the local's address is taken, and thus the relative addresses of the fields are observable to /// client code. -fn escaping_locals(excluded: &IndexVec, body: &Body<'_>) -> BitSet { +fn escaping_locals(excluded: &BitSet, body: &Body<'_>) -> BitSet { let mut set = BitSet::new_empty(body.local_decls.len()); set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count)); for (local, decl) in body.local_decls().iter_enumerated() { - if decl.ty.is_union() || decl.ty.is_enum() || excluded[local] { + if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) { set.insert(local); } } @@ -172,7 +174,7 @@ fn replace_flattened_locals<'tcx>( body: &mut Body<'tcx>, replacements: ReplacementMap<'tcx>, ) -> BitSet { - let mut all_dead_locals = BitSet::new_empty(body.local_decls.len()); + let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len()); for (local, replacements) in replacements.fragments.iter_enumerated() { if replacements.is_some() { all_dead_locals.insert(local); diff --git a/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff b/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff new file mode 100644 index 0000000000000..038e6c6bd9005 --- /dev/null +++ b/tests/mir-opt/dataflow-const-prop/enum.mutate_discriminant.DataflowConstProp.diff @@ -0,0 +1,26 @@ +- // MIR for `mutate_discriminant` before DataflowConstProp ++ // MIR for `mutate_discriminant` after DataflowConstProp + + fn mutate_discriminant() -> u8 { + let mut _0: u8; // return place in scope 0 at $DIR/enum.rs:+0:29: +0:31 + let mut _1: std::option::Option; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL + let mut _2: isize; // in scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL + + bb0: { + discriminant(_1) = 1; // scope 0 at $DIR/enum.rs:+4:13: +4:34 + (((_1 as variant#1).0: NonZeroUsize).0: usize) = const 0_usize; // scope 0 at $DIR/enum.rs:+6:13: +6:64 + _2 = discriminant(_1); // scope 0 at $SRC_DIR/core/src/intrinsics/mir.rs:LL:COL + switchInt(_2) -> [0: bb1, otherwise: bb2]; // scope 0 at $DIR/enum.rs:+9:13: +12:14 + } + + bb1: { + _0 = const 1_u8; // scope 0 at $DIR/enum.rs:+15:13: +15:20 + return; // scope 0 at $DIR/enum.rs:+16:13: +16:21 + } + + bb2: { + _0 = const 2_u8; // scope 0 at $DIR/enum.rs:+19:13: +19:20 + unreachable; // scope 0 at $DIR/enum.rs:+20:13: +20:26 + } + } + diff --git a/tests/mir-opt/dataflow-const-prop/enum.rs b/tests/mir-opt/dataflow-const-prop/enum.rs index 13288577dea3f..7ea405bd9c408 100644 --- a/tests/mir-opt/dataflow-const-prop/enum.rs +++ b/tests/mir-opt/dataflow-const-prop/enum.rs @@ -1,13 +1,52 @@ // unit-test: DataflowConstProp -// Not trackable, because variants could be aliased. +#![feature(custom_mir, core_intrinsics, rustc_attrs)] + +use std::intrinsics::mir::*; + enum E { V1(i32), V2(i32) } -// EMIT_MIR enum.main.DataflowConstProp.diff -fn main() { +// EMIT_MIR enum.simple.DataflowConstProp.diff +fn simple() { let e = E::V1(0); let x = match e { E::V1(x) => x, E::V2(x) => x }; } + +#[rustc_layout_scalar_valid_range_start(1)] +#[rustc_nonnull_optimization_guaranteed] +struct NonZeroUsize(usize); + +// EMIT_MIR enum.mutate_discriminant.DataflowConstProp.diff +#[custom_mir(dialect = "runtime", phase = "post-cleanup")] +fn mutate_discriminant() -> u8 { + mir!( + let x: Option; + { + SetDiscriminant(x, 1); + // This assignment overwrites the niche in which the discriminant is stored. + place!(Field(Field(Variant(x, 1), 0), 0)) = 0_usize; + // So we cannot know the value of this discriminant. + let a = Discriminant(x); + match a { + 0 => bb1, + _ => bad, + } + } + bb1 = { + RET = 1; + Return() + } + bad = { + RET = 2; + Unreachable() + } + ) +} + +fn main() { + simple(); + mutate_discriminant(); +} diff --git a/tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff b/tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff similarity index 84% rename from tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff rename to tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff index d049c79d78def..1fb65e6584525 100644 --- a/tests/mir-opt/dataflow-const-prop/enum.main.DataflowConstProp.diff +++ b/tests/mir-opt/dataflow-const-prop/enum.simple.DataflowConstProp.diff @@ -1,8 +1,8 @@ -- // MIR for `main` before DataflowConstProp -+ // MIR for `main` after DataflowConstProp +- // MIR for `simple` before DataflowConstProp ++ // MIR for `simple` after DataflowConstProp - fn main() -> () { - let mut _0: (); // return place in scope 0 at $DIR/enum.rs:+0:11: +0:11 + fn simple() -> () { + let mut _0: (); // return place in scope 0 at $DIR/enum.rs:+0:13: +0:13 let _1: E; // in scope 0 at $DIR/enum.rs:+1:9: +1:10 let mut _3: isize; // in scope 0 at $DIR/enum.rs:+2:23: +2:31 scope 1 { @@ -25,8 +25,10 @@ StorageLive(_1); // scope 0 at $DIR/enum.rs:+1:9: +1:10 _1 = E::V1(const 0_i32); // scope 0 at $DIR/enum.rs:+1:13: +1:21 StorageLive(_2); // scope 1 at $DIR/enum.rs:+2:9: +2:10 - _3 = discriminant(_1); // scope 1 at $DIR/enum.rs:+2:19: +2:20 - switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20 +- _3 = discriminant(_1); // scope 1 at $DIR/enum.rs:+2:19: +2:20 +- switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20 ++ _3 = const 0_isize; // scope 1 at $DIR/enum.rs:+2:19: +2:20 ++ switchInt(const 0_isize) -> [0: bb3, 1: bb1, otherwise: bb2]; // scope 1 at $DIR/enum.rs:+2:13: +2:20 } bb1: { @@ -50,7 +52,7 @@ } bb4: { - _0 = const (); // scope 0 at $DIR/enum.rs:+0:11: +3:2 + _0 = const (); // scope 0 at $DIR/enum.rs:+0:13: +3:2 StorageDead(_2); // scope 1 at $DIR/enum.rs:+3:1: +3:2 StorageDead(_1); // scope 0 at $DIR/enum.rs:+3:1: +3:2 return; // scope 0 at $DIR/enum.rs:+3:2: +3:2