forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rollup merge of rust-lang#107172 - cjgillot:no-nal, r=nagisa
Reimplement NormalizeArrayLen based on SsaLocals Based on rust-lang#106908 Fixes rust-lang#105929 Only the last commit "Reimplement NormalizeArrayLen" is relevant.
- Loading branch information
Showing
17 changed files
with
191 additions
and
436 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
322 changes: 69 additions & 253 deletions
322
compiler/rustc_mir_transform/src/normalize_array_len.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,288 +1,104 @@ | ||
//! This pass eliminates casting of arrays into slices when their length | ||
//! is taken using `.len()` method. Handy to preserve information in MIR for const prop | ||
|
||
use crate::ssa::SsaLocals; | ||
use crate::MirPass; | ||
use rustc_data_structures::fx::FxIndexMap; | ||
use rustc_data_structures::intern::Interned; | ||
use rustc_index::bit_set::BitSet; | ||
use rustc_index::vec::IndexVec; | ||
use rustc_middle::mir::visit::*; | ||
use rustc_middle::mir::*; | ||
use rustc_middle::ty::{self, ReErased, Region, TyCtxt}; | ||
|
||
const MAX_NUM_BLOCKS: usize = 800; | ||
const MAX_NUM_LOCALS: usize = 3000; | ||
use rustc_middle::ty::{self, TyCtxt}; | ||
use rustc_mir_dataflow::impls::borrowed_locals; | ||
|
||
pub struct NormalizeArrayLen; | ||
|
||
impl<'tcx> MirPass<'tcx> for NormalizeArrayLen { | ||
fn is_enabled(&self, sess: &rustc_session::Session) -> bool { | ||
// See #105929 | ||
sess.mir_opt_level() >= 4 && sess.opts.unstable_opts.unsound_mir_opts | ||
sess.mir_opt_level() >= 3 | ||
} | ||
|
||
#[instrument(level = "trace", skip(self, tcx, body))] | ||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
// early returns for edge cases of highly unrolled functions | ||
if body.basic_blocks.len() > MAX_NUM_BLOCKS { | ||
return; | ||
} | ||
if body.local_decls.len() > MAX_NUM_LOCALS { | ||
return; | ||
} | ||
debug!(def_id = ?body.source.def_id()); | ||
normalize_array_len_calls(tcx, body) | ||
} | ||
} | ||
|
||
pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
// We don't ever touch terminators, so no need to invalidate the CFG cache | ||
let basic_blocks = body.basic_blocks.as_mut_preserves_cfg(); | ||
let local_decls = &mut body.local_decls; | ||
fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); | ||
let borrowed_locals = borrowed_locals(body); | ||
let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals); | ||
|
||
// do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]` | ||
let mut interesting_locals = BitSet::new_empty(local_decls.len()); | ||
for (local, decl) in local_decls.iter_enumerated() { | ||
match decl.ty.kind() { | ||
ty::Array(..) => { | ||
interesting_locals.insert(local); | ||
} | ||
ty::Ref(.., ty, Mutability::Not) => match ty.kind() { | ||
ty::Array(..) => { | ||
interesting_locals.insert(local); | ||
} | ||
_ => {} | ||
}, | ||
_ => {} | ||
} | ||
} | ||
if interesting_locals.is_empty() { | ||
// we have found nothing to analyze | ||
return; | ||
} | ||
let num_intesting_locals = interesting_locals.count(); | ||
let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); | ||
let mut patches_scratchpad = | ||
FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); | ||
let mut replacements_scratchpad = | ||
FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default()); | ||
for block in basic_blocks { | ||
// make length calls for arrays [T; N] not to decay into length calls for &[T] | ||
// that forbids constant propagation | ||
normalize_array_len_call( | ||
tcx, | ||
block, | ||
local_decls, | ||
&interesting_locals, | ||
&mut state, | ||
&mut patches_scratchpad, | ||
&mut replacements_scratchpad, | ||
); | ||
state.clear(); | ||
patches_scratchpad.clear(); | ||
replacements_scratchpad.clear(); | ||
} | ||
} | ||
let slice_lengths = compute_slice_length(tcx, &ssa, body); | ||
debug!(?slice_lengths); | ||
|
||
struct Patcher<'a, 'tcx> { | ||
tcx: TyCtxt<'tcx>, | ||
patches_scratchpad: &'a FxIndexMap<usize, usize>, | ||
replacements_scratchpad: &'a mut FxIndexMap<usize, Local>, | ||
local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>, | ||
statement_idx: usize, | ||
Replacer { tcx, slice_lengths }.visit_body_preserves_cfg(body); | ||
} | ||
|
||
impl<'tcx> Patcher<'_, 'tcx> { | ||
fn patch_expand_statement( | ||
&mut self, | ||
statement: &mut Statement<'tcx>, | ||
) -> Option<std::vec::IntoIter<Statement<'tcx>>> { | ||
let idx = self.statement_idx; | ||
if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() { | ||
let mut statements = Vec::with_capacity(2); | ||
|
||
// we are at statement that performs a cast. The only sound way is | ||
// to create another local that performs a similar copy without a cast and then | ||
// use this copy in the Len operation | ||
|
||
match &statement.kind { | ||
StatementKind::Assign(box ( | ||
.., | ||
Rvalue::Cast( | ||
CastKind::Pointer(ty::adjustment::PointerCast::Unsize), | ||
operand, | ||
_, | ||
), | ||
)) => { | ||
match operand { | ||
Operand::Copy(place) | Operand::Move(place) => { | ||
// create new local | ||
let ty = operand.ty(self.local_decls, self.tcx); | ||
let local_decl = LocalDecl::with_source_info(ty, statement.source_info); | ||
let local = self.local_decls.push(local_decl); | ||
// make it live | ||
let mut make_live_statement = statement.clone(); | ||
make_live_statement.kind = StatementKind::StorageLive(local); | ||
statements.push(make_live_statement); | ||
// copy into it | ||
|
||
let operand = Operand::Copy(*place); | ||
let mut make_copy_statement = statement.clone(); | ||
let assign_to = Place::from(local); | ||
let rvalue = Rvalue::Use(operand); | ||
make_copy_statement.kind = | ||
StatementKind::Assign(Box::new((assign_to, rvalue))); | ||
statements.push(make_copy_statement); | ||
|
||
// to reorder we have to copy and make NOP | ||
statements.push(statement.clone()); | ||
statement.make_nop(); | ||
|
||
self.replacements_scratchpad.insert(len_statemnt_idx, local); | ||
} | ||
_ => { | ||
unreachable!("it's a bug in the implementation") | ||
} | ||
} | ||
} | ||
_ => { | ||
unreachable!("it's a bug in the implementation") | ||
fn compute_slice_length<'tcx>( | ||
tcx: TyCtxt<'tcx>, | ||
ssa: &SsaLocals, | ||
body: &Body<'tcx>, | ||
) -> IndexVec<Local, Option<ty::Const<'tcx>>> { | ||
let mut slice_lengths = IndexVec::from_elem(None, &body.local_decls); | ||
|
||
for (local, rvalue) in ssa.assignments(body) { | ||
match rvalue { | ||
Rvalue::Cast( | ||
CastKind::Pointer(ty::adjustment::PointerCast::Unsize), | ||
operand, | ||
cast_ty, | ||
) => { | ||
let operand_ty = operand.ty(body, tcx); | ||
debug!(?operand_ty); | ||
if let Some(operand_ty) = operand_ty.builtin_deref(true) | ||
&& let ty::Array(_, len) = operand_ty.ty.kind() | ||
&& let Some(cast_ty) = cast_ty.builtin_deref(true) | ||
&& let ty::Slice(..) = cast_ty.ty.kind() | ||
{ | ||
slice_lengths[local] = Some(*len); | ||
} | ||
} | ||
|
||
self.statement_idx += 1; | ||
|
||
Some(statements.into_iter()) | ||
} else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() { | ||
let mut statements = Vec::with_capacity(2); | ||
|
||
match &statement.kind { | ||
StatementKind::Assign(box (into, Rvalue::Len(place))) => { | ||
let add_deref = if let Some(..) = place.as_local() { | ||
false | ||
} else if let Some(..) = place.local_or_deref_local() { | ||
true | ||
} else { | ||
unreachable!("it's a bug in the implementation") | ||
}; | ||
// replace len statement | ||
let mut len_statement = statement.clone(); | ||
let mut place = Place::from(local); | ||
if add_deref { | ||
place = self.tcx.mk_place_deref(place); | ||
} | ||
len_statement.kind = | ||
StatementKind::Assign(Box::new((*into, Rvalue::Len(place)))); | ||
statements.push(len_statement); | ||
|
||
// make temporary dead | ||
let mut make_dead_statement = statement.clone(); | ||
make_dead_statement.kind = StatementKind::StorageDead(local); | ||
statements.push(make_dead_statement); | ||
|
||
// make original statement NOP | ||
statement.make_nop(); | ||
// The length information is stored in the fat pointer, so we treat `operand` as a value. | ||
Rvalue::Use(operand) => { | ||
if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() { | ||
slice_lengths[local] = slice_lengths[rhs]; | ||
} | ||
_ => { | ||
unreachable!("it's a bug in the implementation") | ||
} | ||
// The length information is stored in the fat pointer. | ||
// Reborrowing copies length information from one pointer to the other. | ||
Rvalue::Ref(_, _, rhs) | Rvalue::AddressOf(_, rhs) => { | ||
if let [PlaceElem::Deref] = rhs.projection[..] { | ||
slice_lengths[local] = slice_lengths[rhs.local]; | ||
} | ||
} | ||
|
||
self.statement_idx += 1; | ||
|
||
Some(statements.into_iter()) | ||
} else { | ||
self.statement_idx += 1; | ||
None | ||
_ => {} | ||
} | ||
} | ||
|
||
slice_lengths | ||
} | ||
|
||
fn normalize_array_len_call<'tcx>( | ||
struct Replacer<'tcx> { | ||
tcx: TyCtxt<'tcx>, | ||
block: &mut BasicBlockData<'tcx>, | ||
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, | ||
interesting_locals: &BitSet<Local>, | ||
state: &mut FxIndexMap<Local, usize>, | ||
patches_scratchpad: &mut FxIndexMap<usize, usize>, | ||
replacements_scratchpad: &mut FxIndexMap<usize, Local>, | ||
) { | ||
for (statement_idx, statement) in block.statements.iter_mut().enumerate() { | ||
match &mut statement.kind { | ||
StatementKind::Assign(box (place, rvalue)) => { | ||
match rvalue { | ||
Rvalue::Cast( | ||
CastKind::Pointer(ty::adjustment::PointerCast::Unsize), | ||
operand, | ||
cast_ty, | ||
) => { | ||
let Some(local) = place.as_local() else { return }; | ||
match operand { | ||
Operand::Copy(place) | Operand::Move(place) => { | ||
let Some(operand_local) = place.local_or_deref_local() else { return; }; | ||
if !interesting_locals.contains(operand_local) { | ||
return; | ||
} | ||
let operand_ty = local_decls[operand_local].ty; | ||
match (operand_ty.kind(), cast_ty.kind()) { | ||
(ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { | ||
if of_ty_src == of_ty_dst { | ||
// this is a cast from [T; N] into [T], so we are good | ||
state.insert(local, statement_idx); | ||
} | ||
} | ||
// current way of patching doesn't allow to work with `mut` | ||
( | ||
ty::Ref( | ||
Region(Interned(ReErased, _)), | ||
operand_ty, | ||
Mutability::Not, | ||
), | ||
ty::Ref( | ||
Region(Interned(ReErased, _)), | ||
cast_ty, | ||
Mutability::Not, | ||
), | ||
) => { | ||
match (operand_ty.kind(), cast_ty.kind()) { | ||
// current way of patching doesn't allow to work with `mut` | ||
(ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => { | ||
if of_ty_src == of_ty_dst { | ||
// this is a cast from [T; N] into [T], so we are good | ||
state.insert(local, statement_idx); | ||
} | ||
} | ||
_ => {} | ||
} | ||
} | ||
_ => {} | ||
} | ||
} | ||
_ => {} | ||
} | ||
} | ||
Rvalue::Len(place) => { | ||
let Some(local) = place.local_or_deref_local() else { | ||
return; | ||
}; | ||
if let Some(cast_statement_idx) = state.get(&local).copied() { | ||
patches_scratchpad.insert(cast_statement_idx, statement_idx); | ||
} | ||
} | ||
_ => { | ||
// invalidate | ||
state.remove(&place.local); | ||
} | ||
} | ||
} | ||
_ => {} | ||
} | ||
} | ||
slice_lengths: IndexVec<Local, Option<ty::Const<'tcx>>>, | ||
} | ||
|
||
let mut patcher = Patcher { | ||
tcx, | ||
patches_scratchpad: &*patches_scratchpad, | ||
replacements_scratchpad, | ||
local_decls, | ||
statement_idx: 0, | ||
}; | ||
impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> { | ||
fn tcx(&self) -> TyCtxt<'tcx> { | ||
self.tcx | ||
} | ||
|
||
block.expand_statements(|st| patcher.patch_expand_statement(st)); | ||
fn visit_rvalue(&mut self, rvalue: &mut Rvalue<'tcx>, loc: Location) { | ||
if let Rvalue::Len(place) = rvalue | ||
&& let [PlaceElem::Deref] = &place.projection[..] | ||
&& let Some(len) = self.slice_lengths[place.local] | ||
{ | ||
*rvalue = Rvalue::Use(Operand::Constant(Box::new(Constant { | ||
span: rustc_span::DUMMY_SP, | ||
user_ty: None, | ||
literal: ConstantKind::from_const(len, self.tcx), | ||
}))); | ||
} | ||
self.super_rvalue(rvalue, loc); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.