From 51918dcc51360d79cec771ef556ae9181ebaeb2e Mon Sep 17 00:00:00 2001 From: Michael Benfield Date: Mon, 10 Oct 2022 17:29:38 +0000 Subject: [PATCH] rustc_codegen_ssa: Better code generation for niche discriminants. In some cases we can avoid arithmetic before checking whether a niche represents an untagged variant. This is relevant to #101872 --- compiler/rustc_codegen_ssa/src/mir/place.rs | 194 +++++++++++++++----- src/test/codegen/enum-match.rs | 112 +++++++++++ src/test/ui/enum-discriminant/get_discr.rs | 114 ++++++++++++ 3 files changed, 371 insertions(+), 49 deletions(-) create mode 100644 src/test/codegen/enum-match.rs create mode 100644 src/test/ui/enum-discriminant/get_discr.rs diff --git a/compiler/rustc_codegen_ssa/src/mir/place.rs b/compiler/rustc_codegen_ssa/src/mir/place.rs index 9c18df5643f1c..8c72d6d1fbe36 100644 --- a/compiler/rustc_codegen_ssa/src/mir/place.rs +++ b/compiler/rustc_codegen_ssa/src/mir/place.rs @@ -209,7 +209,9 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { bx: &mut Bx, cast_to: Ty<'tcx>, ) -> V { - let cast_to = bx.cx().immediate_backend_type(bx.cx().layout_of(cast_to)); + let cast_to_layout = bx.cx().layout_of(cast_to); + let cast_to_size = cast_to_layout.layout.size(); + let cast_to = bx.cx().immediate_backend_type(cast_to_layout); if self.layout.abi.is_uninhabited() { return bx.cx().const_undef(cast_to); } @@ -229,7 +231,8 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { // Read the tag/niche-encoded discriminant from memory. let tag = self.project_field(bx, tag_field); - let tag = bx.load_operand(tag); + let tag_op = bx.load_operand(tag); + let tag_imm = tag_op.immediate(); // Decode the discriminant (specifically if it's niche-encoded). match *tag_encoding { @@ -242,68 +245,161 @@ impl<'a, 'tcx, V: CodegenObject> PlaceRef<'tcx, V> { Int(_, signed) => !tag_scalar.is_bool() && signed, _ => false, }; - bx.intcast(tag.immediate(), cast_to, signed) + bx.intcast(tag_imm, cast_to, signed) } TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => { - // Rebase from niche values to discriminants, and check - // whether the result is in range for the niche variants. - let niche_llty = bx.cx().immediate_backend_type(tag.layout); - let tag = tag.immediate(); - - // We first compute the "relative discriminant" (wrt `niche_variants`), - // that is, if `n = niche_variants.end() - niche_variants.start()`, - // we remap `niche_start..=niche_start + n` (which may wrap around) - // to (non-wrap-around) `0..=n`, to be able to check whether the - // discriminant corresponds to a niche variant with one comparison. - // We also can't go directly to the (variant index) discriminant - // and check that it is in the range `niche_variants`, because - // that might not fit in the same type, on top of needing an extra - // comparison (see also the comment on `let niche_discr`). - let relative_discr = if niche_start == 0 { - // Avoid subtracting `0`, which wouldn't work for pointers. - // FIXME(eddyb) check the actual primitive type here. - tag + // Cast to an integer so we don't have to treat a pointer as a + // special case. + let (tag, tag_llty) = if tag_scalar.primitive().is_ptr() { + let t = bx.type_isize(); + let tag = bx.ptrtoint(tag_imm, t); + (tag, t) } else { - bx.sub(tag, bx.cx().const_uint_big(niche_llty, niche_start)) + (tag_imm, bx.cx().immediate_backend_type(tag_op.layout)) }; + + let tag_size = tag_scalar.size(bx.cx()); + let max_unsigned = tag_size.unsigned_int_max(); + let max_signed = tag_size.signed_int_max() as u128; + let min_signed = max_signed + 1; let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32(); - let is_niche = if relative_max == 0 { - // Avoid calling `const_uint`, which wouldn't work for pointers. - // Also use canonical == 0 instead of non-canonical u<= 0. - // FIXME(eddyb) check the actual primitive type here. - bx.icmp(IntPredicate::IntEQ, relative_discr, bx.cx().const_null(niche_llty)) + let niche_end = niche_start.wrapping_add(relative_max as u128) & max_unsigned; + let range = tag_scalar.valid_range(bx.cx()); + + let sle = |lhs: u128, rhs: u128| -> bool { + // Signed and unsigned comparisons give the same results, + // except that in signed comparisons an integer with the + // sign bit set is less than one with the sign bit clear. + // Toggle the sign bit to do a signed comparison. + (lhs ^ min_signed) <= (rhs ^ min_signed) + }; + + // We have a subrange `niche_start..=niche_end` inside `range`. + // If the value of the tag is inside this subrange, it's a + // "niche value", an increment of the discriminant. Otherwise it + // indicates the untagged variant. + // A general algorithm to extract the discriminant from the tag + // is: + // relative_tag = tag - niche_start + // is_niche = relative_tag <= (ule) relative_max + // discr = if is_niche { + // cast(relative_tag) + niche_variants.start() + // } else { + // untagged_variant + // } + // However, we will likely be able to emit simpler code. + + // Find the least and greatest values in `range`, considered + // both as signed and unsigned. + let (low_unsigned, high_unsigned) = if range.start <= range.end { + (range.start, range.end) + } else { + (0, max_unsigned) + }; + let (low_signed, high_signed) = if sle(range.start, range.end) { + (range.start, range.end) } else { - let relative_max = bx.cx().const_uint(niche_llty, relative_max as u64); - bx.icmp(IntPredicate::IntULE, relative_discr, relative_max) + (min_signed, max_signed) + }; + + let niches_ule = niche_start <= niche_end; + let niches_sle = sle(niche_start, niche_end); + let cast_smaller = cast_to_size <= tag_size; + + // In the algorithm above, we can change + // cast(relative_tag) + niche_variants.start() + // into + // cast(tag) + (niche_variants.start() - niche_start) + // if either the casted type is no larger than the original + // type, or if the niche values are contiguous (in either the + // signed or unsigned sense). + let can_incr_after_cast = cast_smaller || niches_ule || niches_sle; + + let data_for_boundary_niche = || -> Option<(IntPredicate, u128)> { + if !can_incr_after_cast { + None + } else if niche_start == low_unsigned { + Some((IntPredicate::IntULE, niche_end)) + } else if niche_end == high_unsigned { + Some((IntPredicate::IntUGE, niche_start)) + } else if niche_start == low_signed { + Some((IntPredicate::IntSLE, niche_end)) + } else if niche_end == high_signed { + Some((IntPredicate::IntSGE, niche_start)) + } else { + None + } }; - // NOTE(eddyb) this addition needs to be performed on the final - // type, in case the niche itself can't represent all variant - // indices (e.g. `u8` niche with more than `256` variants, - // but enough uninhabited variants so that the remaining variants - // fit in the niche). - // In other words, `niche_variants.end - niche_variants.start` - // is representable in the niche, but `niche_variants.end` - // might not be, in extreme cases. - let niche_discr = { - let relative_discr = if relative_max == 0 { - // HACK(eddyb) since we have only one niche, we know which - // one it is, and we can avoid having a dynamic value here. - bx.cx().const_uint(cast_to, 0) + let (is_niche, tagged_discr, delta) = if relative_max == 0 { + // Best case scenario: only one tagged variant. This will + // likely become just a comparison and a jump. + // The algorithm is: + // is_niche = tag == niche_start + // discr = if is_niche { + // niche_start + // } else { + // untagged_variant + // } + let niche_start = bx.cx().const_uint_big(tag_llty, niche_start); + let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start); + let tagged_discr = + bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64); + (is_niche, tagged_discr, 0) + } else if let Some((predicate, constant)) = data_for_boundary_niche() { + // The niche values are either the lowest or the highest in + // `range`. We can avoid the first subtraction in the + // algorithm. + // The algorithm is now this: + // is_niche = tag <= niche_end + // discr = if is_niche { + // cast(tag) + (niche_variants.start() - niche_start) + // } else { + // untagged_variant + // } + // (the first line may instead be tag >= niche_start, + // and may be a signed or unsigned comparison) + let is_niche = + bx.icmp(predicate, tag, bx.cx().const_uint_big(tag_llty, constant)); + let cast_tag = if cast_smaller { + bx.intcast(tag, cast_to, false) + } else if niches_ule { + bx.zext(tag, cast_to) } else { - bx.intcast(relative_discr, cast_to, false) + bx.sext(tag, cast_to) }; - bx.add( + + let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start); + (is_niche, cast_tag, delta) + } else { + // The special cases don't apply, so we'll have to go with + // the general algorithm. + let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start)); + let cast_tag = bx.intcast(relative_discr, cast_to, false); + let is_niche = bx.icmp( + IntPredicate::IntULE, relative_discr, - bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64), - ) + bx.cx().const_uint(tag_llty, relative_max as u64), + ); + (is_niche, cast_tag, niche_variants.start().as_u32() as u128) }; - bx.select( + let tagged_discr = if delta == 0 { + tagged_discr + } else { + bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta)) + }; + + let discr = bx.select( is_niche, - niche_discr, + tagged_discr, bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64), - ) + ); + + // In principle we could insert assumes on the possible range of `discr`, but + // currently in LLVM this seems to be a pessimization. + + discr } } } diff --git a/src/test/codegen/enum-match.rs b/src/test/codegen/enum-match.rs new file mode 100644 index 0000000000000..efab189fd7b8f --- /dev/null +++ b/src/test/codegen/enum-match.rs @@ -0,0 +1,112 @@ +// compile-flags: -Copt-level=1 +// only-x86_64 + +#![crate_type = "lib"] + +// Check each of the 3 cases for `codegen_get_discr`. + +// Case 0: One tagged variant. +pub enum Enum0 { + A(bool), + B, +} + +// CHECK: define i8 @match0{{.*}} +// CHECK-NEXT: start: +// CHECK-NEXT: %1 = icmp eq i8 %0, 2 +// CHECK-NEXT: %2 = and i8 %0, 1 +// CHECK-NEXT: %.0 = select i1 %1, i8 13, i8 %2 +#[no_mangle] +pub fn match0(e: Enum0) -> u8 { + use Enum0::*; + match e { + A(b) => b as u8, + B => 13, + } +} + +// Case 1: Niche values are on a boundary for `range`. +pub enum Enum1 { + A(bool), + B, + C, +} + +// CHECK: define i8 @match1{{.*}} +// CHECK-NEXT: start: +// CHECK-NEXT: %1 = icmp ugt i8 %0, 1 +// CHECK-NEXT: %2 = zext i8 %0 to i64 +// CHECK-NEXT: %3 = add nsw i64 %2, -1 +// CHECK-NEXT: %_2 = select i1 %1, i64 %3, i64 0 +// CHECK-NEXT: switch i64 %_2, label {{.*}} [ +#[no_mangle] +pub fn match1(e: Enum1) -> u8 { + use Enum1::*; + match e { + A(b) => b as u8, + B => 13, + C => 100, + } +} + +// Case 2: Special cases don't apply. +pub enum X { + _2=2, _3, _4, _5, _6, _7, _8, _9, _10, _11, + _12, _13, _14, _15, _16, _17, _18, _19, _20, + _21, _22, _23, _24, _25, _26, _27, _28, _29, + _30, _31, _32, _33, _34, _35, _36, _37, _38, + _39, _40, _41, _42, _43, _44, _45, _46, _47, + _48, _49, _50, _51, _52, _53, _54, _55, _56, + _57, _58, _59, _60, _61, _62, _63, _64, _65, + _66, _67, _68, _69, _70, _71, _72, _73, _74, + _75, _76, _77, _78, _79, _80, _81, _82, _83, + _84, _85, _86, _87, _88, _89, _90, _91, _92, + _93, _94, _95, _96, _97, _98, _99, _100, _101, + _102, _103, _104, _105, _106, _107, _108, _109, + _110, _111, _112, _113, _114, _115, _116, _117, + _118, _119, _120, _121, _122, _123, _124, _125, + _126, _127, _128, _129, _130, _131, _132, _133, + _134, _135, _136, _137, _138, _139, _140, _141, + _142, _143, _144, _145, _146, _147, _148, _149, + _150, _151, _152, _153, _154, _155, _156, _157, + _158, _159, _160, _161, _162, _163, _164, _165, + _166, _167, _168, _169, _170, _171, _172, _173, + _174, _175, _176, _177, _178, _179, _180, _181, + _182, _183, _184, _185, _186, _187, _188, _189, + _190, _191, _192, _193, _194, _195, _196, _197, + _198, _199, _200, _201, _202, _203, _204, _205, + _206, _207, _208, _209, _210, _211, _212, _213, + _214, _215, _216, _217, _218, _219, _220, _221, + _222, _223, _224, _225, _226, _227, _228, _229, + _230, _231, _232, _233, _234, _235, _236, _237, + _238, _239, _240, _241, _242, _243, _244, _245, + _246, _247, _248, _249, _250, _251, _252, _253, +} + +pub enum Enum2 { + A(X), + B, + C, + D, + E, +} + +// CHECK: define i8 @match2{{.*}} +// CHECK-NEXT: start: +// CHECK-NEXT: %1 = add i8 %0, 2 +// CHECK-NEXT: %2 = zext i8 %1 to i64 +// CHECK-NEXT: %3 = icmp ult i8 %1, 4 +// CHECK-NEXT: %4 = add nuw nsw i64 %2, 1 +// CHECK-NEXT: %_2 = select i1 %3, i64 %4, i64 0 +// CHECK-NEXT: switch i64 %_2, label {{.*}} [ +#[no_mangle] +pub fn match2(e: Enum2) -> u8 { + use Enum2::*; + match e { + A(b) => b as u8, + B => 13, + C => 100, + D => 200, + E => 250, + } +} diff --git a/src/test/ui/enum-discriminant/get_discr.rs b/src/test/ui/enum-discriminant/get_discr.rs new file mode 100644 index 0000000000000..71eea4e0f78af --- /dev/null +++ b/src/test/ui/enum-discriminant/get_discr.rs @@ -0,0 +1,114 @@ +// run-pass + +// Now that there are several variations on the code generated in +// `codegen_get_discr`, let's make sure the various cases yield the correct +// result. + +// To get the discriminant of an E value, there are no shortcuts - we must +// do the full algorithm. +#[repr(u8)] +pub enum X1 { + _1 = 1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, + _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, + _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, + _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, + _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, + _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, + _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, _126, _127, _128, + _129, _130, _131, _132, _133, _134, _135, _136, _137, _138, _139, _140, _141, _142, _143, _144, + _145, _146, _147, _148, _149, _150, _151, _152, _153, _154, _155, _156, _157, _158, _159, _160, + _161, _162, _163, _164, _165, _166, _167, _168, _169, _170, _171, _172, _173, _174, _175, _176, + _177, _178, _179, _180, _181, _182, _183, _184, _185, _186, _187, _188, _189, _190, _191, _192, + _193, _194, _195, _196, _197, _198, _199, _200, _201, _202, _203, _204, _205, _206, _207, _208, + _209, _210, _211, _212, _213, _214, _215, _216, _217, _218, _219, _220, _221, _222, _223, _224, + _225, _226, _227, _228, _229, _230, _231, _232, _233, _234, _235, _236, _237, _238, _239, _240, + _241, _242, _243, _244, _245, _246, _247, _248, _249, _250, _251, _252, _253, _254, +} + +#[repr(i8)] +pub enum X2 { + _1 = -1, _2 = 0, _3 = 1, +} + +#[repr(i8)] +pub enum X3 { + _1 = -128, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, + _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, + _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, + _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, + _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, + _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, + _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, _126, _127, _128, + _129, _130, _131, _132, _133, _134, _135, _136, _137, _138, _139, _140, _141, _142, _143, _144, + _145, _146, _147, _148, _149, _150, _151, _152, _153, _154, _155, _156, _157, _158, _159, _160, + _161, _162, _163, _164, _165, _166, _167, _168, _169, _170, _171, _172, _173, _174, _175, _176, + _177, _178, _179, _180, _181, _182, _183, _184, _185, _186, _187, _188, _189, _190, _191, _192, + _193, _194, _195, _196, _197, _198, _199, _200, _201, _202, _203, _204, _205, _206, _207, _208, + _209, _210, _211, _212, _213, _214, _215, _216, _217, _218, _219, _220, _221, _222, _223, _224, + _225, _226, _227, _228, _229, _230, _231, _232, _233, _234, _235, _236, _237, _238, _239, _240, + _241, _242, _243, _244, _245, _246, _247, _248, _249, _250, _251, _252, _253, _254, +} + +#[repr(i8)] +pub enum X4 { + _1 = -126, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, + _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, + _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, + _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, + _65, _66, _67, _68, _69, _70, _71, _72, _73, _74, _75, _76, _77, _78, _79, _80, + _81, _82, _83, _84, _85, _86, _87, _88, _89, _90, _91, _92, _93, _94, _95, _96, + _97, _98, _99, _100, _101, _102, _103, _104, _105, _106, _107, _108, _109, _110, _111, _112, + _113, _114, _115, _116, _117, _118, _119, _120, _121, _122, _123, _124, _125, _126, _127, _128, + _129, _130, _131, _132, _133, _134, _135, _136, _137, _138, _139, _140, _141, _142, _143, _144, + _145, _146, _147, _148, _149, _150, _151, _152, _153, _154, _155, _156, _157, _158, _159, _160, + _161, _162, _163, _164, _165, _166, _167, _168, _169, _170, _171, _172, _173, _174, _175, _176, + _177, _178, _179, _180, _181, _182, _183, _184, _185, _186, _187, _188, _189, _190, _191, _192, + _193, _194, _195, _196, _197, _198, _199, _200, _201, _202, _203, _204, _205, _206, _207, _208, + _209, _210, _211, _212, _213, _214, _215, _216, _217, _218, _219, _220, _221, _222, _223, _224, + _225, _226, _227, _228, _229, _230, _231, _232, _233, _234, _235, _236, _237, _238, _239, _240, + _241, _242, _243, _244, _245, _246, _247, _248, _249, _250, _251, _252, _253, _254, +} + +pub enum E { + A(X), + B, + C, +} + +pub fn match_e(e: E) -> u8 { + use E::*; + match e { + A(_) => 0, + B => 1, + C => 2, + } +} + +fn main() { + assert_eq!(match_e(E::A(X1::_1)), 0); + assert_eq!(match_e(E::A(X1::_2)), 0); + assert_eq!(match_e(E::A(X1::_254)), 0); + assert_eq!(match_e(E::::B), 1); + assert_eq!(match_e(E::::C), 2); + assert_eq!(match_e(E::A(X2::_1)), 0); + assert_eq!(match_e(E::A(X2::_2)), 0); + assert_eq!(match_e(E::A(X2::_3)), 0); + assert_eq!(match_e(E::::B), 1); + assert_eq!(match_e(E::::C), 2); + assert_eq!(match_e(E::A(X3::_1)), 0); + assert_eq!(match_e(E::A(X3::_2)), 0); + assert_eq!(match_e(E::A(X3::_254)), 0); + assert_eq!(match_e(E::::B), 1); + assert_eq!(match_e(E::::C), 2); + assert_eq!(match_e(E::A(X4::_1)), 0); + assert_eq!(match_e(E::A(X4::_2)), 0); + assert_eq!(match_e(E::A(X4::_254)), 0); + assert_eq!(match_e(E::::B), 1); + assert_eq!(match_e(E::::C), 2); + assert_eq!(match_e(E::A(false)), 0); + assert_eq!(match_e(E::A(true)), 0); + assert_eq!(match_e(E::::B), 1); + assert_eq!(match_e(E::::C), 2); +}