Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an interpreter in MIR jump threading #119461

Merged
merged 3 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 155 additions & 107 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants};

use crate::cost_checker::CostChecker;
use crate::dataflow_const_prop::DummyMachine;

pub struct JumpThreading;

Expand All @@ -71,6 +76,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
let mut finder = TOFinder {
tcx,
param_env,
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
body,
arena: &arena,
map: &map,
Expand All @@ -88,7 +94,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
debug!(?discr, ?bb);

let discr_ty = discr.ty(body, tcx).ty;
let Ok(discr_layout) = tcx.layout_of(param_env.and(discr_ty)) else { continue };
let Ok(discr_layout) = finder.ecx.layout_of(discr_ty) else { continue };

let Some(discr) = finder.map.find(discr.as_ref()) else { continue };
debug!(?discr);
Expand Down Expand Up @@ -142,6 +148,7 @@ struct ThreadingOpportunity {
struct TOFinder<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
body: &'a Body<'tcx>,
map: &'a Map,
loop_headers: &'a BitSet<BasicBlock>,
Expand Down Expand Up @@ -329,25 +336,82 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
fn process_immediate(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
rhs: ImmTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let register_opportunity = |c: Condition| {
debug!(?bb, ?c.target, "register");
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
};

let conditions = state.try_get_idx(lhs, self.map)?;
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
conditions.iter_matches(int).for_each(register_opportunity);
}

None
}

/// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
#[instrument(level = "trace", skip(self))]
fn process_constant(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
constant: OpTy<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
self.map.for_each_projection_value(
lhs,
constant,
&mut |elem, op| match elem {
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
TrackElem::Discriminant => {
let variant = self.ecx.read_discriminant(op).ok()?;
let discr_value =
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
Some(discr_value.into())
}
TrackElem::DerefLen => {
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
let len_usize = op.len(&self.ecx).ok()?;
let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
Some(ImmTy::from_uint(len_usize, layout).into())
}
},
&mut |place, op| {
if let Some(conditions) = state.try_get_idx(place, self.map)
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
&& let Some(imm) = imm.right()
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
{
conditions.iter_matches(int).for_each(|c: Condition| {
self.opportunities
.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
})
}
},
);
}

#[instrument(level = "trace", skip(self))]
fn process_operand(
&mut self,
bb: BasicBlock,
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let constant =
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
conditions.iter_matches(constant).for_each(register_opportunity);
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
self.process_constant(bb, lhs, constant, state);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
Expand All @@ -359,6 +423,84 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
None
}

#[instrument(level = "trace", skip(self))]
fn process_assign(
&mut self,
bb: BasicBlock,
lhs_place: &Place<'tcx>,
rhs: &Rvalue<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<!> {
let lhs = self.map.find(lhs_place.as_ref())?;
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
&& let Ok(discr_value) =
self.ecx.discriminant_for_variant(agg_ty, *variant_index)
{
self.process_immediate(bb, discr_target, discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes little sense. If you want to evaluate the constant, call eval, not normalize. normalize is only meant for when you want to keep this around as a mir::Const but for some reason you want to be sure that if it can evaluated, it is a value (but if it cannot be evaluated that's okay, too). It's a pretty strange operation that should be rarely needed.

let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
});
state.insert_value_idx(place, conds, self.map);
}

_ => {}
}

None
}

#[instrument(level = "trace", skip(self))]
fn process_statement(
&mut self,
Expand All @@ -374,18 +516,6 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// Below, `lhs` is the return value of `mutated_statement`,
// the place to which `conditions` apply.

let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
Some(Operand::const_from_scalar(
self.tcx,
discr.ty,
scalar.into(),
rustc_span::DUMMY_SP,
))
};

match &stmt.kind {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
Expand All @@ -395,7 +525,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
// nothing.
let enum_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
let enum_layout = self.ecx.layout_of(enum_ty).ok()?;
let writes_discriminant = match enum_layout.variants {
Variants::Single { index } => {
assert_eq!(index, *variant_index);
Expand All @@ -408,8 +538,8 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
} => *variant_index != untagged_variant,
};
if writes_discriminant {
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
self.process_operand(bb, discr_target, &discr, state)?;
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
self.process_immediate(bb, discr_target, discr, state)?;
}
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
Expand All @@ -420,89 +550,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
}
StatementKind::Assign(box (lhs_place, rhs)) => {
if let Some(lhs) = self.map.find(lhs_place.as_ref()) {
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let rhs = self.map.find_discr(rhs.as_ref())?;
state.insert_place_idx(rhs, lhs, self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return None,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) =
self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) =
discriminant_for_variant(agg_ty, *variant_index)
{
self.process_operand(bb, discr_target, &discr_value, state);
}
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) =
self.map.apply(lhs, TrackElem::Field(field_index))
{
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inversing polarity.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let conds = conditions.map(self.arena, Condition::inv);
state.insert_value_idx(place, conds, self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
// Create a condition on `rhs ?= B`.
Rvalue::BinaryOp(
op,
box (
Operand::Move(place) | Operand::Copy(place),
Operand::Constant(value),
)
| box (
Operand::Constant(value),
Operand::Move(place) | Operand::Copy(place),
),
) => {
let conditions = state.try_get_idx(lhs, self.map)?;
let place = self.map.find(place.as_ref())?;
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
let value = value
.const_
.normalize(self.tcx, self.param_env)
.try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) {
Polarity::Eq
} else {
Polarity::Ne
},
..c
});
state.insert_value_idx(place, conds, self.map);
}

_ => {}
}
}
self.process_assign(bb, lhs_place, rhs, state)?;
}
_ => {}
}
Expand Down Expand Up @@ -577,7 +625,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {

let discr = discr.place()?;
let discr_ty = discr.ty(self.body, self.tcx).ty;
let discr_layout = self.tcx.layout_of(self.param_env.and(discr_ty)).ok()?;
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let conditions = state.try_get(discr.as_ref(), self.map)?;

if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
Expand Down
Loading
Loading