Skip to content

Commit

Permalink
analyze: discard rewrites on bounds checking code (#1177)
Browse files Browse the repository at this point in the history
This branch adds logic for identifying the bounds-checking code that
rustc automatically inserts around array and slice accesses and for
discarding rewrites that would normally be generated for that code.
Trying to rewrite this code is pointless because the code is derived
automatically from other parts of the function, and it can trigger
errors due to the rewriter not having a distinct place in the source
code where it can apply the rewrites.

In the future, we could potentially extend this to skip rewrites on
overflow assertions as well.
  • Loading branch information
spernsteiner authored Dec 6, 2024
2 parents 43fb2f8 + 6b266a9 commit b069bb6
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 36 deletions.
11 changes: 8 additions & 3 deletions c2rust-analyze/src/rewrite/expr/distribute.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::rewrite::expr::mir_op::{self, MirRewrite};
use crate::rewrite::expr::unlower::{MirOrigin, MirOriginDesc, PreciseLoc};
use crate::rewrite::expr::unlower::{MirOriginDesc, PreciseLoc, UnlowerMap};
use itertools::Itertools;
use log::*;
use rustc_hir::HirId;
use rustc_middle::mir::Location;
use rustc_middle::ty::TyCtxt;
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;

struct RewriteInfo {
rw: mir_op::RewriteKind,
Expand Down Expand Up @@ -120,12 +120,17 @@ impl From<RewriteInfo> for DistRewrite {
/// afterward.
pub fn distribute(
tcx: TyCtxt,
unlower_map: BTreeMap<PreciseLoc, MirOrigin>,
unlower_map: UnlowerMap,
mir_rewrites: HashMap<Location, Vec<MirRewrite>>,
) -> HashMap<HirId, Vec<DistRewrite>> {
let mut info_map = HashMap::<HirId, Vec<RewriteInfo>>::new();

for (loc, mir_rws) in mir_rewrites {
if unlower_map.discard_rewrites_for(loc) {
trace!("discarding {} rewrites for {loc:?}", mir_rws.len());
continue;
}

for mir_rw in mir_rws {
let key = PreciseLoc {
loc,
Expand Down
24 changes: 20 additions & 4 deletions c2rust-analyze/src/rewrite/expr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use self::mir_op::MirRewrite;
use self::unlower::{MirOrigin, PreciseLoc};
use self::unlower::{PreciseLoc, UnlowerMap};
use crate::context::{AnalysisCtxt, Assignment};
use crate::pointee_type::PointeeTypes;
use crate::pointer_id::PointerTable;
Expand All @@ -9,7 +9,7 @@ use rustc_hir::BodyId;
use rustc_middle::mir::{Body, Location};
use rustc_middle::ty::TyCtxt;
use rustc_span::Span;
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;

mod convert;
mod distribute;
Expand Down Expand Up @@ -57,7 +57,7 @@ pub fn gen_expr_rewrites<'tcx>(
fn debug_print_unlower_map<'tcx>(
tcx: TyCtxt<'tcx>,
mir: &Body<'tcx>,
unlower_map: &BTreeMap<PreciseLoc, MirOrigin>,
unlower_map: &UnlowerMap,
mir_rewrites: &HashMap<Location, Vec<MirRewrite>>,
) {
let print_for_loc = |loc| {
Expand All @@ -69,7 +69,15 @@ fn debug_print_unlower_map<'tcx>(
.push(&rw.kind);
}

for (k, v) in unlower_map.range(&PreciseLoc { loc, sub: vec![] }..) {
if unlower_map.discard_rewrites_for(loc) {
eprintln!(" DISCARD all rewrites for this location");
}

let mut found_at_least_one_origin = false;
for (k, v) in unlower_map
.origins_map()
.range(&PreciseLoc { loc, sub: vec![] }..)
{
if k.loc != loc {
break;
}
Expand All @@ -79,6 +87,14 @@ fn debug_print_unlower_map<'tcx>(
for rw_kind in rewrites_by_subloc.remove(&sublocs).unwrap_or_default() {
eprintln!(" {rw_kind:?}");
}
found_at_least_one_origin = true;
}

if !found_at_least_one_origin {
let span = mir
.stmt_at(loc)
.either(|s| s.source_info.span, |t| t.source_info.span);
eprintln!(" {span:?} (no unlowering entries found)");
}

for (sublocs, rw_kinds) in rewrites_by_subloc {
Expand Down
176 changes: 147 additions & 29 deletions c2rust-analyze/src/rewrite/expr/unlower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_middle::ty::adjustment::{Adjust, AutoBorrow, AutoBorrowMutability, Poi
use rustc_middle::ty::{TyCtxt, TypeckResults};
use rustc_span::Span;
use std::collections::btree_map::{BTreeMap, Entry};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub struct PreciseLoc {
Expand Down Expand Up @@ -66,13 +66,37 @@ pub enum MirOriginDesc {
LoadFromTempForAdjustment(usize),
}

#[derive(Clone, Debug, Default)]
pub struct UnlowerMap {
/// Maps MIR (sub)locations to the HIR node that produced each one, if known.
origins: BTreeMap<PreciseLoc, MirOrigin>,
/// MIR locations for which we discard all rewrites. This is used for "derived" statements,
/// such as the `Len` + `Lt` + `Assert` operations that make up an array bounds check, where
/// rewrites applied to the main statement will automatically affect the derived statements
/// when the rewritten code is compiled.
discard: HashSet<Location>,
}

impl UnlowerMap {
pub fn get(&self, key: &PreciseLoc) -> Option<&MirOrigin> {
self.origins.get(key)
}

pub fn origins_map(&self) -> &BTreeMap<PreciseLoc, MirOrigin> {
&self.origins
}

pub fn discard_rewrites_for(&self, loc: Location) -> bool {
self.discard.contains(&loc)
}
}

struct UnlowerVisitor<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
mir: &'a Body<'tcx>,
typeck_results: &'tcx TypeckResults<'tcx>,
span_index: SpanIndex<Location>,
/// Maps MIR (sub)locations to the HIR node that produced each one, if known.
unlower_map: BTreeMap<PreciseLoc, MirOrigin>,
unlower_map: UnlowerMap,

/// When processing the `hir::Expr` identified by the `HirId`, append some locations to the
/// list retrieved from the `SpanIndex`. This is used in cases where some MIR statements have
Expand Down Expand Up @@ -113,7 +137,7 @@ impl<'a, 'tcx> UnlowerVisitor<'a, 'tcx> {
loc,
sub: sub_loc.to_owned(),
};
match self.unlower_map.entry(key) {
match self.unlower_map.origins.entry(key) {
Entry::Vacant(e) => {
e.insert(origin);
}
Expand Down Expand Up @@ -212,15 +236,14 @@ impl<'a, 'tcx> UnlowerVisitor<'a, 'tcx> {
.filter(|&loc| !self.should_ignore_statement(loc)),
);
}
let locs = locs;
if locs.is_empty() {
return;
}

let warn = |desc| {
let warn = |locs: &[_], desc| {
warn!("{}", desc);
info!("locs:");
for &loc in &locs {
for &loc in locs {
self.mir.stmt_at(loc).either(
|stmt| info!(" {:?}: {:?}", locs, stmt),
|term| info!(" {:?}: {:?}", locs, term),
Expand All @@ -237,21 +260,22 @@ impl<'a, 'tcx> UnlowerVisitor<'a, 'tcx> {
let (loc, mir_pl, mir_rv) = match self.get_sole_assign(&locs) {
Some(x) => x,
None => {
warn("expected exactly one StatementKind::Assign");
warn(&locs, "expected exactly one StatementKind::Assign");
return;
}
};
self.record(loc, &[], ex);
self.visit_expr_place(pl, loc, vec![SubLoc::Dest], mir_pl, &[]);
self.visit_expr_rvalue(rv, loc, vec![SubLoc::Rvalue], mir_rv, &[]);
return;
}

hir::ExprKind::Call(_, args) | hir::ExprKind::MethodCall(_, args, _) => {
// Handle adjustments on the call's output first.
let (_mir_pl, mut cursor) = match self.make_visit_expr_cursor(&locs) {
Some(x @ (pl, _)) if is_var(pl) => x,
_ => {
warn("expected final Assign to store into var");
warn(&locs, "expected final Assign to store into var");
debug!(
"visit_expr_inner: bail out: expr at {:?} isn't assigned to a var",
ex.span
Expand All @@ -272,15 +296,15 @@ impl<'a, 'tcx> UnlowerVisitor<'a, 'tcx> {
..
} => {
if !is_var(destination) {
warn("expected final Call to store into var");
warn(&locs, "expected final Call to store into var");
return;
}
args
}
_ => unreachable!("ExprMir::Call should always contain Call terminator"),
},
_ => {
warn("expected MIR Call for HIR Call/MethodCall");
warn(&locs, "expected MIR Call for HIR Call/MethodCall");
return;
}
};
Expand Down Expand Up @@ -309,23 +333,121 @@ impl<'a, 'tcx> UnlowerVisitor<'a, 'tcx> {
warn!("NYI: extra locations {:?} in Call", &locs[..locs.len() - 1]);
}
}
return;
}

_ => {
// For all other `ExprKind`s, we expect the last `loc` to be an assignment storing
// the final result into a temporary.
let (_mir_pl, mut cursor) = match self.make_visit_expr_cursor(&locs) {
Some(x @ (pl, _)) if is_var(pl) => x,
_ => {
warn("expected final Assign to store into var");
return;
// Remaining cases fall through to the default behavior below.
hir::ExprKind::Index(_arr_ex, _idx_ex) => {
// Look for the following pattern:
// _3 = Len(((*_1).0: [u32; 4]))
// _4 = Lt(_2, _3)
// assert(
// move _4,
// "index out of bounds: the length is {} but the index is {}",
// move _3,
// _2,
// ) -> [success: bb1, unwind: bb2];
let mut match_pattern = || -> Option<()> {
let mut iter = locs.iter().enumerate();
// This pattern of `iter.by_ref().filter_map(..).next()` advances `iter` until
// it produces an item matching the `filter_map` predicate. The next call with
// this pattern will continue searching from the following item.
let (len_idx, len_var) = iter
.by_ref()
.filter_map(|(i, &loc)| {
// Look for `_len = Len(_)`
let stmt = self.mir.stmt_at(loc).left()?;
if let mir::StatementKind::Assign(ref x) = stmt.kind {
let (ref pl, ref rv) = **x;
let pl_var = pl.as_local()?;
if matches!(rv, mir::Rvalue::Len(_)) {
return Some((i, pl_var));
}
}
None
})
.next()?;
let (lt_idx, lt_var) = iter
.by_ref()
.filter_map(|(i, &loc)| {
// Look for `_ok = Lt(_, _len)`
let stmt = self.mir.stmt_at(loc).left()?;
if let mir::StatementKind::Assign(ref x) = stmt.kind {
let (ref pl, ref rv) = **x;
let pl_var = pl.as_local()?;
if let mir::Rvalue::BinaryOp(mir::BinOp::Lt, ref ops) = *rv {
let (_, ref op2) = **ops;
let op2_var = op2.place()?.as_local()?;
if op2_var == len_var {
return Some((i, pl_var));
}
}
}
None
})
.next()?;
let assert_idx = iter
.by_ref()
.filter_map(|(i, &loc)| {
// Look for `Assert(_ok, ..)`
let term = self.mir.stmt_at(loc).right()?;
if let mir::TerminatorKind::Assert { ref cond, .. } = term.kind {
let cond_var = cond.place()?.as_local()?;
if cond_var == lt_var {
return Some(i);
}
}
None
})
.next()?;

// All three parts were found. Mark them as `discard`, then remove them from
// `locs`.
self.unlower_map.discard.insert(locs[len_idx]);
self.unlower_map.discard.insert(locs[lt_idx]);
self.unlower_map.discard.insert(locs[assert_idx]);

if lt_idx == len_idx + 1 && assert_idx == len_idx + 2 {
// All three locations are consecutive. Remove them with `drain`.
locs.drain(lt_idx..=assert_idx);
} else {
// Remove the three locations separately. Remove in reverse order to avoid
// perturbing the other indices.
debug_assert!(assert_idx > lt_idx);
debug_assert!(lt_idx > len_idx);
locs.remove(assert_idx);
locs.remove(lt_idx);
locs.remove(len_idx);
}

Some(())
};
self.record_desc(cursor.loc, &[], ex, MirOriginDesc::StoreIntoLocal);
self.walk_expr(ex, &mut cursor);
self.finish_visit_expr_cursor(ex, cursor);
// `match_pattern` returns `Option` only so we can bail out with `?`. The result
// is unused.
let _ = match_pattern();
}

_ => {}
}

// `locs` can become empty if some locations were found, but all of them were consumed by
// earlier processing.
if locs.is_empty() {
return;
}

// For all other `ExprKind`s, we expect the last `loc` to be an assignment storing the
// final result into a temporary.
let (_mir_pl, mut cursor) = match self.make_visit_expr_cursor(&locs) {
Some(x @ (pl, _)) if is_var(pl) => x,
_ => {
warn(&locs, "expected final Assign to store into var");
return;
}
};
self.record_desc(cursor.loc, &[], ex, MirOriginDesc::StoreIntoLocal);
self.walk_expr(ex, &mut cursor);
self.finish_visit_expr_cursor(ex, cursor);
}

/// Try to create a `VisitExprCursor` from the RHS of statement `locs.last()`. Returns the LHS
Expand Down Expand Up @@ -1033,14 +1155,10 @@ fn build_span_index(mir: &Body<'_>) -> SpanIndex<Location> {
/// This function returns a `BTreeMap`, which supports iterating in sorted order. This allows
/// looking up entries by a prefix of their key (for example, finding all entries on a given
/// `Location` regardless of their `SubLoc`s) using the `BTreeMap::range` method.
pub fn unlower<'tcx>(
tcx: TyCtxt<'tcx>,
mir: &Body<'tcx>,
hir_body_id: hir::BodyId,
) -> BTreeMap<PreciseLoc, MirOrigin> {
pub fn unlower<'tcx>(tcx: TyCtxt<'tcx>, mir: &Body<'tcx>, hir_body_id: hir::BodyId) -> UnlowerMap {
// If this MIR body came from a `#[derive]`, ignore it.
if util::is_automatically_derived(tcx, mir) {
return BTreeMap::new();
return UnlowerMap::default();
}

let typeck_results = tcx.typeck_body(hir_body_id);
Expand All @@ -1056,7 +1174,7 @@ pub fn unlower<'tcx>(
mir,
typeck_results,
span_index,
unlower_map: BTreeMap::new(),
unlower_map: UnlowerMap::default(),
append_extra_locations: HashMap::new(),
};
visitor.visit_body(hir);
Expand Down

0 comments on commit b069bb6

Please sign in to comment.