Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,13 @@ impl WrappingRange {
Self { start: 0, end: size.unsigned_int_max() }
}

pub fn full_signed(size: Size) -> Self {
Self {
start: (size.signed_int_min() as u128) & size.unsigned_int_max(),
end: size.signed_int_max() as u128,
}
}

/// Returns `true` if `v` is contained in the range.
#[inline(always)]
pub fn contains(&self, v: u128) -> bool {
Expand Down Expand Up @@ -1492,6 +1499,11 @@ impl WrappingRange {
Ok(start <= end)
}
}

#[inline]
pub fn no_wraparound(&self, size: Size, signed: bool) -> Result<bool, RangeFull> {
if signed { self.no_signed_wraparound(size) } else { self.no_unsigned_wraparound(size) }
}
}

impl fmt::Debug for WrappingRange {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1486,6 +1486,11 @@ impl<'tcx> Ty<'tcx> {
matches!(self.kind(), Int(_))
}

#[inline]
pub fn is_unsigned(self) -> bool {
matches!(self.kind(), Uint(_))
}

#[inline]
pub fn is_ptr_sized_integral(self) -> bool {
matches!(self.kind(), Int(ty::IntTy::Isize) | Uint(ty::UintTy::Usize))
Expand Down
206 changes: 202 additions & 4 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,16 @@
//! that contain `AllocId`s.

use std::borrow::Cow;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::iter;

use either::Either;
use hashbrown::hash_table::{Entry, HashTable};
use itertools::Itertools as _;
use rustc_abi::{self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx};
use rustc_abi::{
self as abi, BackendRepr, FIRST_VARIANT, FieldIdx, Primitive, Size, VariantIdx, WrappingRange,
};
use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{
Expand All @@ -107,6 +111,7 @@ use rustc_middle::mir::interpret::GlobalAlloc;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::HasTypingEnv;
use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::DUMMY_SP;
use smallvec::SmallVec;
Expand Down Expand Up @@ -1369,17 +1374,18 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}
}

if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) {
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
if let Some(value) = self.simplify_binary_inner(op, ty, lhs_ty, lhs, rhs) {
return Some(value);
}
let ty = op.ty(self.tcx, lhs_ty, self.ty(rhs));
let value = Value::BinaryOp(op, lhs, rhs);
Some(self.insert(ty, value))
}

fn simplify_binary_inner(
&mut self,
op: BinOp,
ty: Ty<'tcx>,
lhs_ty: Ty<'tcx>,
lhs: VnIndex,
rhs: VnIndex,
Expand Down Expand Up @@ -1488,7 +1494,13 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
(BinOp::Eq, a, b) if a == b => self.insert_bool(true),
(BinOp::Ne, Left(a), Left(b)) => self.insert_bool(a != b),
(BinOp::Ne, a, b) if a == b => self.insert_bool(false),
_ => return None,
_ => {
if let Some(result) = self.simplify_binary_range(op, lhs_ty, lhs, rhs) {
self.insert_scalar(ty, result)
} else {
return None;
}
}
};

if op.is_overflowing() {
Expand All @@ -1500,6 +1512,192 @@ impl<'body, 'a, 'tcx> VnState<'body, 'a, 'tcx> {
}
}

fn collect_range(&mut self, value: VnIndex) -> Option<WrappingRange> {
let ty = self.ty(value);
let layout = self.ecx.layout_of(ty).ok()?;
if !layout.backend_repr.is_scalar() {
return None;
}
if let Some(constant) = self.eval_to_const(value) {
let scalar = self.ecx.read_scalar(constant).discard_err()?;
let bits = scalar.to_bits(constant.layout.size).discard_err()?;
return Some(WrappingRange { start: bits, end: bits });
}
match self.get(value) {
Value::Discriminant(discr) => {
let ty::Adt(adt, _) = self.ty(discr).kind() else {
return None;
};
if !adt.is_enum() {
return None;
}
let discr_ty = adt.repr().discr_type().to_ty(self.tcx);
let discr_layout = self.ecx.layout_of(discr_ty).ok()?;
let discrs: Vec<_> = adt
.discriminants(self.tcx)
.map(|(_, discr)| ImmTy::from_uint(discr.val, discr_layout))
.sorted_by(|x, y| {
let cmp = self.ecx.binary_op(BinOp::Cmp, x, y).unwrap();
let cmp = cmp.to_scalar_int().unwrap().to_i8();
match cmp {
-1 => Ordering::Less,
0 => Ordering::Equal,
1 => Ordering::Greater,
_ => unreachable!(),
}
})
.collect();
if let [one] = &discrs[..] {
let one = one.to_scalar().to_bits(discr_layout.size).discard_err()?;
Some(WrappingRange { start: one, end: one })
} else if let [start, .., end] = &discrs[..]
&& let Some(start) = start.to_scalar().to_bits(discr_layout.size).discard_err()
&& let Some(end) = end.to_scalar().to_bits(discr_layout.size).discard_err()
&& let range = (WrappingRange { start, end })
&& range.no_wraparound(discr_layout.size, discr_ty.is_signed()) == Ok(true)
{
// Prefer a non-wrapping range if one exists.
Some(range)
} else {
// Use the minimal wrapping range.
let mut pair = iter::zip(discrs.iter().skip(1), discrs.iter());
let (mut start, mut end) = pair.next()?;
let range_size = self.ecx.binary_op(BinOp::Sub, &end, &start).discard_err()?;
let mut range_size =
range_size.to_scalar().to_bits(discr_layout.size).discard_err()?;
for (try_start, try_end) in pair {
let try_size = self
.ecx
.binary_op(BinOp::Sub, &try_end, &try_start)
.discard_err()?
.to_scalar()
.to_bits(discr_layout.size)
.discard_err()?;
if try_size < range_size {
range_size = try_size;
start = try_start;
end = try_end;
}
}
let start = start.to_scalar().to_bits(discr_layout.size).discard_err()?;
let end = end.to_scalar().to_bits(discr_layout.size).discard_err()?;
Some(WrappingRange { start, end })
}
}
Value::Cast { kind, value } => {
if kind == CastKind::IntToInt {
let mut source_range = self.collect_range(value)?;
let source_ty = self.ty(value);
let source_layout = self.ecx.layout_of(source_ty).ok()?;
if layout.size > source_layout.size {
// If the source range wraps around, using the full range may have an optimal
// result for the cast. This is a trade-off, for instance, we don't know which is optimal,
// when cast [0i8, -128i8] to [0u16, 65408u16] (nowrap) or [65408u16, 127u16] (smaller).
if source_ty.is_signed() {
if source_range.no_signed_wraparound(source_layout.size) == Ok(false) {
source_range = WrappingRange::full_signed(source_layout.size);
}
} else if source_range.no_unsigned_wraparound(source_layout.size)
== Ok(false)
{
source_range = WrappingRange::full(source_layout.size);
}
}

let WrappingRange { start, end } = source_range;
let start = ImmTy::from_uint(start, source_layout);
let end = ImmTy::from_uint(end, source_layout);

if layout.size < source_layout.size {
let range_size = self
.ecx
.binary_op(BinOp::Sub, &end, &start)
.discard_err()?
.to_scalar()
.to_bits(source_layout.size)
.discard_err()?
+ 1;
if range_size >= layout.size.bits() as u128 {
return None;
}
}

let start = self.ecx.int_to_int_or_float(&start, layout).discard_err().unwrap();
let start = start.to_scalar().to_bits(layout.size).discard_err()?;
let end = self.ecx.int_to_int_or_float(&end, layout).discard_err().unwrap();
let end = end.to_scalar().to_bits(layout.size).discard_err()?;
Some(WrappingRange { start, end })
} else {
None
}
}
Value::Opaque(_)
| Value::Constant { .. }
| Value::Aggregate(_, _)
| Value::Union(_, _)
| Value::RawPtr { .. }
| Value::Repeat(_, _)
| Value::Address { .. }
| Value::Projection(_, _)
| Value::NullaryOp(_)
| Value::UnaryOp(_, _)
| Value::BinaryOp(_, _, _) => None,
}
}

fn simplify_binary_range(
&mut self,
op: BinOp,
lhs_ty: Ty<'tcx>,
lhs: VnIndex,
rhs: VnIndex,
) -> Option<Scalar> {
if !lhs_ty.is_integral() {
return None;
}
if !matches!(op, BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge | BinOp::Cmp) {
return None;
}
let layout = self.ecx.layout_of(lhs_ty).ok()?;
let lhs_range = self.collect_range(lhs)?;
let rhs_range = self.collect_range(rhs)?;
// A wrapping or full range cannot be compared.
if lhs_ty.is_signed()
&& (lhs_range.no_signed_wraparound(layout.size) != Ok(true)
|| rhs_range.no_signed_wraparound(layout.size) != Ok(true))
{
return None;
} else if lhs_ty.is_unsigned()
&& (lhs_range.no_unsigned_wraparound(layout.size) != Ok(true)
|| rhs_range.no_unsigned_wraparound(layout.size) != Ok(true))
{
return None;
}

let lhs_start = ImmTy::from_uint(lhs_range.start, layout);
let lhs_end = ImmTy::from_uint(lhs_range.end, layout);
let rhs_start = ImmTy::from_uint(rhs_range.start, layout);
let rhs_end = ImmTy::from_uint(rhs_range.end, layout);

let cmp = self.ecx.binary_op(op, &lhs_start, &rhs_start).discard_err()?.to_scalar();
if rhs_range.start != rhs_range.end
&& cmp != self.ecx.binary_op(op, &lhs_start, &rhs_end).discard_err()?.to_scalar()
{
return None;
}
if lhs_range.start != lhs_range.end {
if cmp != self.ecx.binary_op(op, &lhs_end, &rhs_start).discard_err()?.to_scalar() {
return None;
}
if rhs_range.start != rhs_range.end
&& cmp != self.ecx.binary_op(op, &lhs_end, &rhs_end).discard_err()?.to_scalar()
{
return None;
}
}
Some(cmp)
}

fn simplify_cast(
&mut self,
initial_kind: &mut CastKind,
Expand Down
52 changes: 52 additions & 0 deletions tests/mir-opt/gvn_range.cast_from_signed_wrapping.GVN.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
- // MIR for `cast_from_signed_wrapping` before GVN
+ // MIR for `cast_from_signed_wrapping` after GVN

fn cast_from_signed_wrapping(_1: SignedWrappingA) -> (bool, bool) {
debug s => _1;
let mut _0: (bool, bool);
let _2: u8;
let _3: SignedWrappingA;
let mut _4: i8;
let mut _5: bool;
let mut _6: u8;
let mut _7: bool;
let mut _8: u8;
scope 1 {
debug s => _2;
}

bb0: {
- StorageLive(_2);
+ nop;
StorageLive(_3);
_3 = copy _1;
- _4 = discriminant(_3);
+ _4 = discriminant(_1);
_2 = move _4 as u8 (IntToInt);
StorageDead(_3);
StorageLive(_5);
StorageLive(_6);
_6 = copy _2;
- _5 = Le(move _6, const 128_u8);
+ _5 = const true;
StorageDead(_6);
StorageLive(_7);
StorageLive(_8);
_8 = copy _2;
- _7 = Gt(move _8, const 128_u8);
+ _7 = const false;
StorageDead(_8);
- _0 = (move _5, move _7);
+ _0 = const (true, false);
StorageDead(_7);
StorageDead(_5);
- StorageDead(_2);
+ nop;
return;
}
+ }
+
+ ALLOC0 (size: 2, align: 1) {
+ 01 00 │ ..
}

Loading
Loading