Skip to content

Commit

Permalink
JumpThreading: Bail out on interp errors
Browse files Browse the repository at this point in the history
  • Loading branch information
clubby789 committed Oct 7, 2024
1 parent 3ccfe76 commit 24db849
Showing 1 changed file with 89 additions and 65 deletions.
154 changes: 89 additions & 65 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
};

for bb in body.basic_blocks.indices() {
finder.start_from_switch(bb);
if finder.start_from_switch(bb).is_none() {
return;
}
}

let opportunities = finder.opportunities;
Expand Down Expand Up @@ -170,8 +172,21 @@ impl<'a> ConditionSet<'a> {
self.iter().filter(move |c| c.matches(value))
}

fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
fn map(
self,
arena: &'a DroplessArena,
f: impl Fn(Condition) -> Option<Condition>,
) -> Option<ConditionSet<'a>> {
let mut all_ok = true;
let set = arena.alloc_from_iter(self.iter().map_while(|c| {
if let Some(c) = f(c) {
Some(c)
} else {
all_ok = false;
None
}
}));
all_ok.then_some(ConditionSet(set))
}
}

Expand All @@ -182,28 +197,28 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {

/// Recursion entry point to find threading opportunities.
#[instrument(level = "trace", skip(self))]
fn start_from_switch(&mut self, bb: BasicBlock) {
fn start_from_switch(&mut self, bb: BasicBlock) -> Option<()> {
let bbdata = &self.body[bb];
if bbdata.is_cleanup || self.loop_headers.contains(bb) {
return;
return Some(());
}
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
let Some(discr) = discr.place() else { return };
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return Some(()) };
let Some(discr) = discr.place() else { return Some(()) };
debug!(?discr, ?bb);

let discr_ty = discr.ty(self.body, self.tcx).ty;
let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
return;
return Some(());
};

let Some(discr) = self.map.find(discr.as_ref()) else { return };
let Some(discr) = self.map.find(discr.as_ref()) else { return Some(()) };
debug!(?discr);

let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
let mut state = State::new_reachable();

let conds = if let Some((value, then, else_)) = targets.as_static_if() {
let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
self.arena.alloc_from_iter([
Condition { value, polarity: Polarity::Eq, target: then },
Condition { value, polarity: Polarity::Ne, target: else_ },
Expand All @@ -217,7 +232,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
let conds = ConditionSet(conds);
state.insert_value_idx(discr, conds, &self.map);

self.find_opportunity(bb, state, cost, 0);
self.find_opportunity(bb, state, cost, 0)
}

/// Recursively walk statements backwards from this bb's terminator to find threading
Expand All @@ -229,27 +244,27 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
mut state: State<ConditionSet<'a>>,
mut cost: CostChecker<'_, 'tcx>,
depth: usize,
) {
) -> Option<()> {
// Do not thread through loop headers.
if self.loop_headers.contains(bb) {
return;
return Some(());
}

debug!(cost = ?cost.cost());
for (statement_index, stmt) in
self.body.basic_blocks[bb].statements.iter().enumerate().rev()
{
if self.is_empty(&state) {
return;
return Some(());
}

cost.visit_statement(stmt, Location { block: bb, statement_index });
if cost.cost() > MAX_COST {
return;
return Some(());
}

// Attempt to turn the `current_condition` on `lhs` into a condition on another place.
self.process_statement(bb, stmt, &mut state);
self.process_statement(bb, stmt, &mut state)?;

// When a statement mutates a place, assignments to that place that happen
// above the mutation cannot fulfill a condition.
Expand All @@ -261,7 +276,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
}

if self.is_empty(&state) || depth >= MAX_BACKTRACK {
return;
return Some(());
}

let last_non_rec = self.opportunities.len();
Expand All @@ -274,9 +289,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
match term.kind {
TerminatorKind::SwitchInt { ref discr, ref targets } => {
self.process_switch_int(discr, targets, bb, &mut state);
self.find_opportunity(pred, state, cost, depth + 1);
self.find_opportunity(pred, state, cost, depth + 1)?;
}
_ => self.recurse_through_terminator(pred, || state, &cost, depth),
_ => self.recurse_through_terminator(pred, || state, &cost, depth)?,
}
} else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
for &pred in predecessors {
Expand All @@ -301,12 +316,13 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
let first = &mut new_tos[0];
*first = ThreadingOpportunity { chain: vec![bb], target: first.target };
self.opportunities.truncate(last_non_rec + 1);
return;
return Some(());
}

for op in self.opportunities[last_non_rec..].iter_mut() {
op.chain.push(bb);
}
Some(())
}

/// Extract the mutated place from a statement.
Expand Down Expand Up @@ -419,23 +435,23 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
) -> Option<()> {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let Some(constant) =
self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
else {
return;
};
let constant = self
.ecx
.eval_mir_constant(&constant.const_, constant.span, None)
.discard_err()?;
self.process_constant(bb, lhs, constant, state);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
let Some(rhs) = self.map.find(rhs.as_ref()) else { return Some(()) };
state.insert_place_idx(rhs, lhs, &self.map);
}
}
Some(())
}

#[instrument(level = "trace", skip(self))]
Expand All @@ -445,22 +461,24 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
lhs_place: &Place<'tcx>,
rhs: &Rvalue<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
) -> Option<()> {
let lhs = self.map.find(lhs_place.as_ref())?;
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::Discriminant(rhs) => {
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return Some(()) };
state.insert_place_idx(rhs, lhs, &self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box ref kind, ref operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return,
AggregateKind::Adt(.., Some(_)) => return Some(()),
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) = self
Expand All @@ -473,31 +491,31 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
idx
} else {
return;
return Some(());
}
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
self.process_operand(bb, field, operand, state);
self.process_operand(bb, field, operand, state)?;
}
}
}
// Transfer the conditions on the copy rhs, after inverting the value of the condition.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
let Some(place) = self.map.find(place.as_ref()) else { return };
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
let conds = conditions.map(self.arena, |mut cond| {
cond.value = self
.ecx
.unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
.unwrap()
.discard_err()?
.to_scalar_int()
.unwrap();
cond
});
.discard_err()?;
Some(cond)
})?;
state.insert_value_idx(place, conds, &self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
Expand All @@ -507,33 +525,36 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
) => {
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
let Some(place) = self.map.find(place.as_ref()) else { return };
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return,
_ => return Some(()),
};
if value.const_.ty().is_floating_point() {
// Floating point equality does not follow bit-patterns.
// -0.0 and NaN both have special rules for equality,
// and therefore we cannot use integer comparisons for them.
// Avoid handling them, though this could be extended in the future.
return;
return Some(());
}
let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.param_env) else {
return;
return Some(());
};
let conds = conditions.map(self.arena, |c| Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
});
let conds = conditions.map(self.arena, |c| {
Some(Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
})
})?;
state.insert_value_idx(place, conds, &self.map);
}

_ => {}
}
Some(())
}

#[instrument(level = "trace", skip(self))]
Expand All @@ -542,7 +563,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
bb: BasicBlock,
stmt: &Statement<'tcx>,
state: &mut State<ConditionSet<'a>>,
) {
) -> Option<()> {
let register_opportunity = |c: Condition| {
debug!(?bb, ?c.target, "register");
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
Expand All @@ -555,13 +576,15 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
StatementKind::SetDiscriminant { box place, variant_index } => {
let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
let Some(discr_target) = self.map.find_discr(place.as_ref()) else {
return Some(());
};
let enum_ty = place.ty(self.body, self.tcx).ty;
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
// nothing.
let Ok(enum_layout) = self.ecx.layout_of(enum_ty) else {
return;
return Some(());
};
let writes_discriminant = match enum_layout.variants {
Variants::Single { index } => {
Expand All @@ -575,26 +598,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
} => *variant_index != untagged_variant,
};
if writes_discriminant {
let Some(discr) =
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
else {
return;
};
let discr =
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()?;
self.process_immediate(bb, discr_target, discr, state);
}
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
Operand::Copy(place) | Operand::Move(place),
)) => {
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else {
return Some(());
};
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
}
StatementKind::Assign(box (lhs_place, rhs)) => {
self.process_assign(bb, lhs_place, rhs, state);
self.process_assign(bb, lhs_place, rhs, state)?;
}
_ => {}
}
Some(())
}

#[instrument(level = "trace", skip(self, state, cost))]
Expand All @@ -605,7 +628,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
state: impl FnOnce() -> State<ConditionSet<'a>>,
cost: &CostChecker<'_, 'tcx>,
depth: usize,
) {
) -> Option<()> {
let term = self.body.basic_blocks[bb].terminator();
let place_to_flood = match term.kind {
// We come from a target, so those are not possible.
Expand All @@ -620,9 +643,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
// Cannot reason about inline asm.
TerminatorKind::InlineAsm { .. } => return,
TerminatorKind::InlineAsm { .. } => return Some(()),
// `SwitchInt` is handled specially.
TerminatorKind::SwitchInt { .. } => return,
TerminatorKind::SwitchInt { .. } => return Some(()),
// We can recurse, no thing particular to do.
TerminatorKind::Goto { .. } => None,
// Flood the overwritten place, and progress through.
Expand All @@ -637,7 +660,8 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
if let Some(place_to_flood) = place_to_flood {
state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
}
self.find_opportunity(bb, state, cost.clone(), depth + 1);
self.find_opportunity(bb, state, cost.clone(), depth + 1)?;
Some(())
}

#[instrument(level = "trace", skip(self))]
Expand Down

0 comments on commit 24db849

Please sign in to comment.