Skip to content

Commit 51a58c5

Browse files
Rollup merge of rust-lang#126232 - RalfJung:dyn-trait-equality, r=oli-obk
interpret: dyn trait metadata check: equate traits in a proper way Hopefully fixes rust-lang/miri#3541... unfortunately we don't have a testcase. The first commit is just a refactor without functional change. r? `@oli-obk`
2 parents c21de3c + 3757136 commit 51a58c5

File tree

11 files changed

+171
-118
lines changed

11 files changed

+171
-118
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+1-8
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
387387
match (&src_pointee_ty.kind(), &dest_pointee_ty.kind()) {
388388
(&ty::Array(_, length), &ty::Slice(_)) => {
389389
let ptr = self.read_pointer(src)?;
390-
// u64 cast is from usize to u64, which is always good
391390
let val = Immediate::new_slice(
392391
ptr,
393392
length.eval_target_usize(*self.tcx, self.param_env),
@@ -405,13 +404,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
405404
let (old_data, old_vptr) = val.to_scalar_pair();
406405
let old_data = old_data.to_pointer(self)?;
407406
let old_vptr = old_vptr.to_pointer(self)?;
408-
let (ty, old_trait) = self.get_ptr_vtable(old_vptr)?;
409-
if old_trait != data_a.principal() {
410-
throw_ub!(InvalidVTableTrait {
411-
expected_trait: data_a,
412-
vtable_trait: old_trait,
413-
});
414-
}
407+
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
415408
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
416409
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
417410
}

compiler/rustc_const_eval/src/interpret/eval_context.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
765765
}
766766
Ok(Some((full_size, full_align)))
767767
}
768-
ty::Dynamic(_, _, ty::Dyn) => {
768+
ty::Dynamic(expected_trait, _, ty::Dyn) => {
769769
let vtable = metadata.unwrap_meta().to_pointer(self)?;
770770
// Read size and align from vtable (already checks size).
771-
Ok(Some(self.get_vtable_size_and_align(vtable)?))
771+
Ok(Some(self.get_vtable_size_and_align(vtable, Some(expected_trait))?))
772772
}
773773

774774
ty::Slice(_) | ty::Str => {

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
432432

433433
sym::vtable_size => {
434434
let ptr = self.read_pointer(&args[0])?;
435-
let (size, _align) = self.get_vtable_size_and_align(ptr)?;
435+
// `None` because we don't know which trait to expect here; any vtable is okay.
436+
let (size, _align) = self.get_vtable_size_and_align(ptr, None)?;
436437
self.write_scalar(Scalar::from_target_usize(size.bytes(), self), dest)?;
437438
}
438439
sym::vtable_align => {
439440
let ptr = self.read_pointer(&args[0])?;
440-
let (_size, align) = self.get_vtable_size_and_align(ptr)?;
441+
// `None` because we don't know which trait to expect here; any vtable is okay.
442+
let (_size, align) = self.get_vtable_size_and_align(ptr, None)?;
441443
self.write_scalar(Scalar::from_target_usize(align.bytes(), self), dest)?;
442444
}
443445

compiler/rustc_const_eval/src/interpret/memory.rs

+12-5
Original file line numberDiff line numberDiff line change
@@ -867,19 +867,26 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
867867
.ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into())
868868
}
869869

870-
pub fn get_ptr_vtable(
870+
/// Get the dynamic type of the given vtable pointer.
871+
/// If `expected_trait` is `Some`, it must be a vtable for the given trait.
872+
pub fn get_ptr_vtable_ty(
871873
&self,
872874
ptr: Pointer<Option<M::Provenance>>,
873-
) -> InterpResult<'tcx, (Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>)> {
875+
expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
876+
) -> InterpResult<'tcx, Ty<'tcx>> {
874877
trace!("get_ptr_vtable({:?})", ptr);
875878
let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?;
876879
if offset.bytes() != 0 {
877880
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
878881
}
879-
match self.tcx.try_get_global_alloc(alloc_id) {
880-
Some(GlobalAlloc::VTable(ty, trait_ref)) => Ok((ty, trait_ref)),
881-
_ => throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))),
882+
let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
883+
else {
884+
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
885+
};
886+
if let Some(expected_trait) = expected_trait {
887+
self.check_vtable_for_type(vtable_trait, expected_trait)?;
882888
}
889+
Ok(ty)
883890
}
884891

885892
pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> {

compiler/rustc_const_eval/src/interpret/place.rs

-49
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use tracing::{instrument, trace};
99

1010
use rustc_ast::Mutability;
1111
use rustc_middle::mir;
12-
use rustc_middle::ty;
1312
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1413
use rustc_middle::ty::Ty;
1514
use rustc_middle::{bug, span_bug};
@@ -1018,54 +1017,6 @@ where
10181017
let layout = self.layout_of(raw.ty)?;
10191018
Ok(self.ptr_to_mplace(ptr.into(), layout))
10201019
}
1021-
1022-
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
1023-
/// Aso returns the vtable.
1024-
pub(super) fn unpack_dyn_trait(
1025-
&self,
1026-
mplace: &MPlaceTy<'tcx, M::Provenance>,
1027-
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
1028-
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, Pointer<Option<M::Provenance>>)> {
1029-
assert!(
1030-
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
1031-
"`unpack_dyn_trait` only makes sense on `dyn*` types"
1032-
);
1033-
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
1034-
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
1035-
if expected_trait.principal() != vtable_trait {
1036-
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
1037-
}
1038-
// This is a kind of transmute, from a place with unsized type and metadata to
1039-
// a place with sized type and no metadata.
1040-
let layout = self.layout_of(ty)?;
1041-
let mplace =
1042-
MPlaceTy { mplace: MemPlace { meta: MemPlaceMeta::None, ..mplace.mplace }, layout };
1043-
Ok((mplace, vtable))
1044-
}
1045-
1046-
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
1047-
/// Also returns the vtable.
1048-
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
1049-
&self,
1050-
val: &P,
1051-
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
1052-
) -> InterpResult<'tcx, (P, Pointer<Option<M::Provenance>>)> {
1053-
assert!(
1054-
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
1055-
"`unpack_dyn_star` only makes sense on `dyn*` types"
1056-
);
1057-
let data = self.project_field(val, 0)?;
1058-
let vtable = self.project_field(val, 1)?;
1059-
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
1060-
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
1061-
if expected_trait.principal() != vtable_trait {
1062-
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
1063-
}
1064-
// `data` is already the right thing but has the wrong type. So we transmute it.
1065-
let layout = self.layout_of(ty)?;
1066-
let data = data.transmute(layout, self)?;
1067-
Ok((data, vtable))
1068-
}
10691020
}
10701021

10711022
// Some nodes are used a lot. Make sure they don't unintentionally get bigger.

compiler/rustc_const_eval/src/interpret/terminator.rs

+18-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::borrow::Cow;
22

33
use either::Either;
4+
use rustc_middle::ty::TyCtxt;
45
use tracing::trace;
56

67
use rustc_middle::span_bug;
@@ -827,20 +828,19 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
827828
};
828829

829830
// Obtain the underlying trait we are working on, and the adjusted receiver argument.
830-
let (vptr, dyn_ty, adjusted_receiver) = if let ty::Dynamic(data, _, ty::DynStar) =
831+
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
831832
receiver_place.layout.ty.kind()
832833
{
833-
let (recv, vptr) = self.unpack_dyn_star(&receiver_place, data)?;
834-
let (dyn_ty, _dyn_trait) = self.get_ptr_vtable(vptr)?;
834+
let recv = self.unpack_dyn_star(&receiver_place, data)?;
835835

836-
(vptr, dyn_ty, recv.ptr())
836+
(data.principal(), recv.layout.ty, recv.ptr())
837837
} else {
838838
// Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`.
839839
// (For that reason we also cannot use `unpack_dyn_trait`.)
840840
let receiver_tail = self
841841
.tcx
842842
.struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env);
843-
let ty::Dynamic(data, _, ty::Dyn) = receiver_tail.kind() else {
843+
let ty::Dynamic(receiver_trait, _, ty::Dyn) = receiver_tail.kind() else {
844844
span_bug!(
845845
self.cur_span(),
846846
"dynamic call on non-`dyn` type {}",
@@ -851,25 +851,24 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
851851

852852
// Get the required information from the vtable.
853853
let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?;
854-
let (dyn_ty, dyn_trait) = self.get_ptr_vtable(vptr)?;
855-
if dyn_trait != data.principal() {
856-
throw_ub!(InvalidVTableTrait {
857-
expected_trait: data,
858-
vtable_trait: dyn_trait,
859-
});
860-
}
854+
let dyn_ty = self.get_ptr_vtable_ty(vptr, Some(receiver_trait))?;
861855

862856
// It might be surprising that we use a pointer as the receiver even if this
863857
// is a by-val case; this works because by-val passing of an unsized `dyn
864858
// Trait` to a function is actually desugared to a pointer.
865-
(vptr, dyn_ty, receiver_place.ptr())
859+
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
866860
};
867861

868862
// Now determine the actual method to call. We can do that in two different ways and
869863
// compare them to ensure everything fits.
870-
let Some(ty::VtblEntry::Method(fn_inst)) =
871-
self.get_vtable_entries(vptr)?.get(idx).copied()
872-
else {
864+
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
865+
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
866+
let trait_ref = self.tcx.erase_regions(trait_ref);
867+
self.tcx.vtable_entries(trait_ref)
868+
} else {
869+
TyCtxt::COMMON_VTABLE_ENTRIES
870+
};
871+
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
873872
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
874873
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
875874
};
@@ -898,7 +897,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
898897
let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty);
899898
args[0] = FnArg::Copy(
900899
ImmTy::from_immediate(
901-
Scalar::from_maybe_pointer(adjusted_receiver, self).into(),
900+
Scalar::from_maybe_pointer(adjusted_recv, self).into(),
902901
self.layout_of(receiver_ty)?,
903902
)
904903
.into(),
@@ -974,11 +973,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
974973
let place = match place.layout.ty.kind() {
975974
ty::Dynamic(data, _, ty::Dyn) => {
976975
// Dropping a trait object. Need to find actual drop fn.
977-
self.unpack_dyn_trait(&place, data)?.0
976+
self.unpack_dyn_trait(&place, data)?
978977
}
979978
ty::Dynamic(data, _, ty::DynStar) => {
980979
// Dropping a `dyn*`. Need to find actual drop fn.
981-
self.unpack_dyn_star(&place, data)?.0
980+
self.unpack_dyn_star(&place, data)?
982981
}
983982
_ => {
984983
debug_assert_eq!(

compiler/rustc_const_eval/src/interpret/traits.rs

+83-18
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
use rustc_infer::infer::TyCtxtInferExt;
2+
use rustc_infer::traits::ObligationCause;
13
use rustc_middle::mir::interpret::{InterpResult, Pointer};
24
use rustc_middle::ty::layout::LayoutOf;
3-
use rustc_middle::ty::{self, Ty, TyCtxt};
5+
use rustc_middle::ty::{self, Ty};
46
use rustc_target::abi::{Align, Size};
7+
use rustc_trait_selection::traits::ObligationCtxt;
58
use tracing::trace;
69

710
use super::util::ensure_monomorphic_enough;
8-
use super::{InterpCx, Machine};
11+
use super::{throw_ub, InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};
912

1013
impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
1114
/// Creates a dynamic vtable for the given type and vtable origin. This is used only for
@@ -33,28 +36,90 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
3336
Ok(vtable_ptr.into())
3437
}
3538

36-
/// Returns a high-level representation of the entries of the given vtable.
37-
pub fn get_vtable_entries(
38-
&self,
39-
vtable: Pointer<Option<M::Provenance>>,
40-
) -> InterpResult<'tcx, &'tcx [ty::VtblEntry<'tcx>]> {
41-
let (ty, poly_trait_ref) = self.get_ptr_vtable(vtable)?;
42-
Ok(if let Some(poly_trait_ref) = poly_trait_ref {
43-
let trait_ref = poly_trait_ref.with_self_ty(*self.tcx, ty);
44-
let trait_ref = self.tcx.erase_regions(trait_ref);
45-
self.tcx.vtable_entries(trait_ref)
46-
} else {
47-
TyCtxt::COMMON_VTABLE_ENTRIES
48-
})
49-
}
50-
5139
pub fn get_vtable_size_and_align(
5240
&self,
5341
vtable: Pointer<Option<M::Provenance>>,
42+
expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
5443
) -> InterpResult<'tcx, (Size, Align)> {
55-
let (ty, _trait_ref) = self.get_ptr_vtable(vtable)?;
44+
let ty = self.get_ptr_vtable_ty(vtable, expected_trait)?;
5645
let layout = self.layout_of(ty)?;
5746
assert!(layout.is_sized(), "there are no vtables for unsized types");
5847
Ok((layout.size, layout.align.abi))
5948
}
49+
50+
/// Check that the given vtable trait is valid for a pointer/reference/place with the given
51+
/// expected trait type.
52+
pub(super) fn check_vtable_for_type(
53+
&self,
54+
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
55+
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
56+
) -> InterpResult<'tcx> {
57+
// Fast path: if they are equal, it's all fine.
58+
if expected_trait.principal() == vtable_trait {
59+
return Ok(());
60+
}
61+
if let (Some(expected_trait), Some(vtable_trait)) =
62+
(expected_trait.principal(), vtable_trait)
63+
{
64+
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
65+
let infcx = self.tcx.infer_ctxt().build();
66+
let ocx = ObligationCtxt::new(&infcx);
67+
let cause = ObligationCause::dummy_with_span(self.cur_span());
68+
// equate the two trait refs after normalization
69+
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
70+
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
71+
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
72+
if ocx.select_all_or_error().is_empty() {
73+
// All good.
74+
return Ok(());
75+
}
76+
}
77+
}
78+
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
79+
}
80+
81+
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
82+
pub(super) fn unpack_dyn_trait(
83+
&self,
84+
mplace: &MPlaceTy<'tcx, M::Provenance>,
85+
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
86+
) -> InterpResult<'tcx, MPlaceTy<'tcx, M::Provenance>> {
87+
assert!(
88+
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
89+
"`unpack_dyn_trait` only makes sense on `dyn*` types"
90+
);
91+
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
92+
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
93+
// This is a kind of transmute, from a place with unsized type and metadata to
94+
// a place with sized type and no metadata.
95+
let layout = self.layout_of(ty)?;
96+
let mplace = mplace.offset_with_meta(
97+
Size::ZERO,
98+
OffsetMode::Wrapping,
99+
MemPlaceMeta::None,
100+
layout,
101+
self,
102+
)?;
103+
Ok(mplace)
104+
}
105+
106+
/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
107+
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
108+
&self,
109+
val: &P,
110+
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
111+
) -> InterpResult<'tcx, P> {
112+
assert!(
113+
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
114+
"`unpack_dyn_star` only makes sense on `dyn*` types"
115+
);
116+
let data = self.project_field(val, 0)?;
117+
let vtable = self.project_field(val, 1)?;
118+
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
119+
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
120+
// `data` is already the right thing but has the wrong type. So we transmute it.
121+
let layout = self.layout_of(ty)?;
122+
let data = data.transmute(layout, self)?;
123+
Ok(data)
124+
}
60125
}

compiler/rustc_const_eval/src/interpret/validity.rs

+7-11
Original file line numberDiff line numberDiff line change
@@ -343,20 +343,16 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
343343
match tail.kind() {
344344
ty::Dynamic(data, _, ty::Dyn) => {
345345
let vtable = meta.unwrap_meta().to_pointer(self.ecx)?;
346-
// Make sure it is a genuine vtable pointer.
347-
let (_dyn_ty, dyn_trait) = try_validation!(
348-
self.ecx.get_ptr_vtable(vtable),
346+
// Make sure it is a genuine vtable pointer for the right trait.
347+
try_validation!(
348+
self.ecx.get_ptr_vtable_ty(vtable, Some(data)),
349349
self.path,
350350
Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) =>
351-
InvalidVTablePtr { value: format!("{vtable}") }
351+
InvalidVTablePtr { value: format!("{vtable}") },
352+
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
353+
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
354+
},
352355
);
353-
// Make sure it is for the right trait.
354-
if dyn_trait != data.principal() {
355-
throw_validation_failure!(
356-
self.path,
357-
InvalidMetaWrongTrait { expected_trait: data, vtable_trait: dyn_trait }
358-
);
359-
}
360356
}
361357
ty::Slice(..) | ty::Str => {
362358
let _len = meta.unwrap_meta().to_target_usize(self.ecx)?;

0 commit comments

Comments
 (0)