diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index c3834607e9236..c282f2211f650 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -1218,6 +1218,12 @@ impl<'tcx> Ty<'tcx> { *self.kind() == Str } + /// Returns true if this type is `&str`. The reference's lifetime is ignored. + #[inline] + pub fn is_imm_ref_str(self) -> bool { + matches!(self.kind(), ty::Ref(_, inner, hir::Mutability::Not) if inner.is_str()) + } + #[inline] pub fn is_param(self, index: u32) -> bool { match self.kind() { diff --git a/compiler/rustc_mir_build/src/builder/matches/buckets.rs b/compiler/rustc_mir_build/src/builder/matches/buckets.rs index fce35aa9ef306..0d2e9bf87585d 100644 --- a/compiler/rustc_mir_build/src/builder/matches/buckets.rs +++ b/compiler/rustc_mir_build/src/builder/matches/buckets.rs @@ -314,7 +314,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } ( - TestKind::Eq { value: test_val, .. }, + TestKind::StringEq { value: test_val, .. }, + TestableCase::Constant { value: case_val, kind: PatConstKind::String }, + ) + | ( + TestKind::ScalarEq { value: test_val, .. }, TestableCase::Constant { value: case_val, kind: PatConstKind::Float | PatConstKind::Other, @@ -347,7 +351,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | TestKind::If | TestKind::SliceLen { .. } | TestKind::Range { .. } - | TestKind::Eq { .. } + | TestKind::StringEq { .. } + | TestKind::ScalarEq { .. } | TestKind::Deref { .. }, _, ) => { diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 386ca61a61241..f0114c2193c3e 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use rustc_abi::FieldIdx; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::thir::*; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; @@ -173,9 +174,21 @@ impl<'tcx> MatchPairTree<'tcx> { PatConstKind::IntOrChar } else if pat_ty.is_floating_point() { PatConstKind::Float + } else if pat_ty.is_str() { + // Deref-patterns can cause string-literal patterns to have + // type `str` instead of the usual `&str`. + if !cx.tcx.features().deref_patterns() { + span_bug!( + pattern.span, + "const pattern has type `str` but deref_patterns is not enabled" + ); + } + PatConstKind::String + } else if pat_ty.is_imm_ref_str() { + PatConstKind::String } else { // FIXME(Zalathar): This still covers several different - // categories (e.g. raw pointer, string, pattern-type) + // categories (e.g. raw pointer, pattern-type) // which could be split out into their own kinds. PatConstKind::Other }; diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 421065a894119..9080e2ba801bf 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1290,9 +1290,10 @@ enum PatConstKind { /// These types don't support `SwitchInt` and require an equality test, /// but can also interact with range pattern tests. Float, + /// Constant string values, tested via string equality. + String, /// Any other constant-pattern is usually tested via some kind of equality /// check. Types that might be encountered here include: - /// - `&str` /// - raw pointers derived from integer values /// - pattern types, e.g. `pattern_type!(u32 is 1..)` Other, @@ -1368,14 +1369,20 @@ enum TestKind<'tcx> { /// Test whether a `bool` is `true` or `false`. If, - /// Test for equality with value, possibly after an unsizing coercion to - /// `cast_ty`, - Eq { + /// Tests the place against a string constant using string equality. + StringEq { + /// Constant `&str` value to test against. value: ty::Value<'tcx>, - // Integer types are handled by `SwitchInt`, and constants with ADT - // types and `&[T]` types are converted back into patterns, so this can - // only be `&str` or floats. - cast_ty: Ty<'tcx>, + /// Type of the corresponding pattern node. Usually `&str`, but could + /// be `str` for patterns like `deref!("..."): String`. + pat_ty: Ty<'tcx>, + }, + + /// Tests the place against a constant using scalar equality. + ScalarEq { + value: ty::Value<'tcx>, + /// Type of the corresponding pattern node. + pat_ty: Ty<'tcx>, }, /// Test whether the value falls within an inclusive or exclusive range. diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 972d9f66faddc..c2e39d47a92ca 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -38,11 +38,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestableCase::Constant { value: _, kind: PatConstKind::IntOrChar } => { TestKind::SwitchInt } - TestableCase::Constant { value, kind: PatConstKind::Float } => { - TestKind::Eq { value, cast_ty: match_pair.pattern_ty } + TestableCase::Constant { value, kind: PatConstKind::String } => { + TestKind::StringEq { value, pat_ty: match_pair.pattern_ty } } - TestableCase::Constant { value, kind: PatConstKind::Other } => { - TestKind::Eq { value, cast_ty: match_pair.pattern_ty } + TestableCase::Constant { value, kind: PatConstKind::Float | PatConstKind::Other } => { + TestKind::ScalarEq { value, pat_ty: match_pair.pattern_ty } } TestableCase::Range(ref range) => { @@ -141,17 +141,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.cfg.terminate(block, self.source_info(match_start_span), terminator); } - TestKind::Eq { value, mut cast_ty } => { + TestKind::StringEq { value, pat_ty } => { let tcx = self.tcx; let success_block = target_block(TestBranch::Success); let fail_block = target_block(TestBranch::Failure); - let mut expect_ty = value.ty; - let mut expect = self.literal_operand(test.span, Const::from_ty_value(tcx, value)); + let expected_value_ty = value.ty; + let expected_value_operand = + self.literal_operand(test.span, Const::from_ty_value(tcx, value)); - let mut place = place; + let mut actual_value_ty = pat_ty; + let mut actual_value_place = place; - match cast_ty.kind() { + match pat_ty.kind() { ty::Str => { // String literal patterns may have type `str` if `deref_patterns` is // enabled, in order to allow `deref!("..."): String`. In this case, `value` @@ -172,11 +174,43 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ref_place, Rvalue::Ref(re_erased, BorrowKind::Shared, place), ); - place = ref_place; - cast_ty = ref_str_ty; + actual_value_place = ref_place; + actual_value_ty = ref_str_ty; } + _ => {} + } + + assert_eq!(expected_value_ty, actual_value_ty); + assert!(actual_value_ty.is_imm_ref_str()); + + // Compare two strings using `::eq`. + // (Interestingly this means that exhaustiveness analysis relies, for soundness, + // on the `PartialEq` impl for `str` to be correct!) + self.string_compare( + block, + success_block, + fail_block, + source_info, + expected_value_operand, + Operand::Copy(actual_value_place), + ); + } + + TestKind::ScalarEq { value, pat_ty } => { + let tcx = self.tcx; + let success_block = target_block(TestBranch::Success); + let fail_block = target_block(TestBranch::Failure); + + let mut expected_value_ty = value.ty; + let mut expected_value_operand = + self.literal_operand(test.span, Const::from_ty_value(tcx, value)); + + let mut actual_value_ty = pat_ty; + let mut actual_value_place = place; + + match pat_ty.kind() { &ty::Pat(base, _) => { - assert_eq!(cast_ty, value.ty); + assert_eq!(pat_ty, value.ty); assert!(base.is_trivially_pure_clone_copy()); let transmuted_place = self.temp(base, test.span); @@ -184,7 +218,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block, self.source_info(scrutinee_span), transmuted_place, - Rvalue::Cast(CastKind::Transmute, Operand::Copy(place), base), + Rvalue::Cast( + CastKind::Transmute, + Operand::Copy(actual_value_place), + base, + ), ); let transmuted_expect = self.temp(base, test.span); @@ -192,54 +230,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block, self.source_info(test.span), transmuted_expect, - Rvalue::Cast(CastKind::Transmute, expect, base), + Rvalue::Cast(CastKind::Transmute, expected_value_operand, base), ); - place = transmuted_place; - expect = Operand::Copy(transmuted_expect); - cast_ty = base; - expect_ty = base; + actual_value_place = transmuted_place; + actual_value_ty = base; + expected_value_operand = Operand::Copy(transmuted_expect); + expected_value_ty = base; } _ => {} } - assert_eq!(expect_ty, cast_ty); - if !cast_ty.is_scalar() { - // Use `PartialEq::eq` instead of `BinOp::Eq` - // (the binop can only handle primitives) - // Make sure that we do *not* call any user-defined code here. - // The only type that can end up here is string literals, which have their - // comparison defined in `core`. - // (Interestingly this means that exhaustiveness analysis relies, for soundness, - // on the `PartialEq` impl for `str` to b correct!) - match *cast_ty.kind() { - ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {} - _ => { - span_bug!( - source_info.span, - "invalid type for non-scalar compare: {cast_ty}" - ) - } - }; - self.string_compare( - block, - success_block, - fail_block, - source_info, - expect, - Operand::Copy(place), - ); - } else { - self.compare( - block, - success_block, - fail_block, - source_info, - BinOp::Eq, - expect, - Operand::Copy(place), - ); - } + assert_eq!(expected_value_ty, actual_value_ty); + assert!(actual_value_ty.is_scalar()); + + self.compare( + block, + success_block, + fail_block, + source_info, + BinOp::Eq, + expected_value_operand, + Operand::Copy(actual_value_place), + ); } TestKind::Range(ref range) => {