diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 0c8ce42ba0..d674f2542f 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -182,35 +182,15 @@ impl SpirvType<'_> { Self::Vector { element, count } => cx.emit_global().type_vector_id(id, element, count), Self::Matrix { element, count } => cx.emit_global().type_matrix_id(id, element, count), Self::Array { element, count } => { - // ArrayStride decoration wants in *bytes* - let element_size = cx - .lookup_type(element) - .sizeof(cx) - .expect("Element of sized array must be sized") - .bytes(); - let mut emit = cx.emit_global(); - let result = emit.type_array_id(id, element, count.def_cx(cx)); - emit.decorate( - result, - Decoration::ArrayStride, - iter::once(Operand::LiteralBit32(element_size as u32)), - ); + let result = cx + .emit_global() + .type_array_id(id, element, count.def_cx(cx)); + Self::decorate_array_stride(result, element, cx); result } Self::RuntimeArray { element } => { - let mut emit = cx.emit_global(); - let result = emit.type_runtime_array_id(id, element); - // ArrayStride decoration wants in *bytes* - let element_size = cx - .lookup_type(element) - .sizeof(cx) - .expect("Element of sized array must be sized") - .bytes(); - emit.decorate( - result, - Decoration::ArrayStride, - iter::once(Operand::LiteralBit32(element_size as u32)), - ); + let result = cx.emit_global().type_runtime_array_id(id, element); + Self::decorate_array_stride(result, element, cx); result } Self::Pointer { pointee } => { @@ -278,6 +258,19 @@ impl SpirvType<'_> { result } + fn decorate_array_stride(result: u32, element: u32, cx: &CodegenCx<'_>) { + let mut emit = cx.emit_global(); + let ty = cx.lookup_type(element); + if let Some(element_size) = ty.physical_size(cx) { + // ArrayStride decoration wants in *bytes* + emit.decorate( + result, + Decoration::ArrayStride, + iter::once(Operand::LiteralBit32(element_size.bytes() as u32)), + ); + } + } + /// `def_with_id` is used by the `RecursivePointeeCache` to handle `OpTypeForwardPointer`: when /// emitting the subsequent `OpTypePointer`, the ID is already known and must be re-used. pub fn def_with_id(self, cx: &CodegenCx<'_>, def_span: Span, id: Word) -> Word { @@ -386,6 +379,35 @@ impl SpirvType<'_> { } } + /// Get the physical size of the type needed for explicit layout decorations. + #[allow(clippy::match_same_arms)] + pub fn physical_size(&self, cx: &CodegenCx<'_>) -> Option { + match *self { + // TODO(jwollen) Handle physical pointers (PhysicalStorageBuffer) + Self::Pointer { .. } => None, + + Self::Adt { size, .. } => size, + + Self::Array { element, count } => Some( + cx.lookup_type(element).physical_size(cx)? + * cx.builder + .lookup_const_scalar(count) + .unwrap() + .try_into() + .unwrap(), + ), + + // Always unsized types + Self::InterfaceBlock { .. } | Self::RayQueryKhr | Self::SampledImage { .. } => None, + + // Descriptor types + Self::Image { .. } | Self::AccelerationStructureKhr | Self::Sampler => None, + + // Primitive types + ty => ty.sizeof(cx), + } + } + /// Replace `&[T]` fields with `&'tcx [T]` ones produced by calling /// `tcx.arena.dropless.alloc_slice(...)` - this is done late for two reasons: /// 1. it avoids allocating in the arena when the cache would be hit anyway, diff --git a/tests/ui/dis/asm_op_decorate.stderr b/tests/ui/dis/asm_op_decorate.stderr index 0eec323876..7364346cab 100644 --- a/tests/ui/dis/asm_op_decorate.stderr +++ b/tests/ui/dis/asm_op_decorate.stderr @@ -13,17 +13,16 @@ OpExecutionMode %1 OriginUpperLeft %2 = OpString "$OPSTRING_FILENAME/asm_op_decorate.rs" OpName %3 "asm_op_decorate::main" OpName %4 "asm_op_decorate::add_decorate" -OpDecorate %5 ArrayStride 4 -OpDecorate %6 Binding 0 -OpDecorate %6 DescriptorSet 0 -%7 = OpTypeVoid -%8 = OpTypeFunction %7 -%9 = OpTypeFloat 32 -%10 = OpTypeImage %9 2D 0 0 0 1 Unknown -%11 = OpTypeSampledImage %10 -%12 = OpTypePointer UniformConstant %11 -%5 = OpTypeRuntimeArray %11 -%13 = OpTypePointer UniformConstant %5 -%6 = OpVariable %13 UniformConstant +OpDecorate %5 Binding 0 +OpDecorate %5 DescriptorSet 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeFloat 32 +%9 = OpTypeImage %8 2D 0 0 0 1 Unknown +%10 = OpTypeSampledImage %9 +%11 = OpTypePointer UniformConstant %10 +%12 = OpTypeRuntimeArray %10 +%13 = OpTypePointer UniformConstant %12 +%5 = OpVariable %13 UniformConstant %14 = OpTypeInt 32 0 %15 = OpConstant %14 1