diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 6b2e4189aea02..838f7cc70b0cf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -357,6 +357,7 @@ def SPV_EXT_shader_atomic_float_add : I32EnumAttrCase<"SPV_EXT_shader_atomi def SPV_EXT_shader_atomic_float_min_max : I32EnumAttrCase<"SPV_EXT_shader_atomic_float_min_max", 1009>; def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image_int64", 1010>; def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>; +def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>; def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>; def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>; @@ -443,6 +444,7 @@ def SPIRV_ExtensionAttr : SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer, SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add, + SPV_EXT_mesh_shader, SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot, SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask, SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod, @@ -1207,6 +1209,12 @@ def SPIRV_C_MeshShadingNV : I32EnumAttrCase<"MeshS Extension<[SPV_NV_mesh_shader]> ]; } +def SPIRV_C_MeshShadingEXT : I32EnumAttrCase<"MeshShadingEXT", 5283> { + list implies = [SPIRV_C_Shader]; + list availability = [ + Extension<[SPV_EXT_mesh_shader]> + ]; +} def SPIRV_C_FragmentDensityEXT : I32EnumAttrCase<"FragmentDensityEXT", 5291> { list implies = [SPIRV_C_Shader]; list availability = [ @@ -1436,7 +1444,7 @@ def SPIRV_CapabilityAttr : SPIRV_C_StorageBuffer8BitAccess, SPIRV_C_StoragePushConstant8, SPIRV_C_DenormPreserve, SPIRV_C_DenormFlushToZero, SPIRV_C_SignedZeroInfNanPreserve, SPIRV_C_RoundingModeRTE, SPIRV_C_RoundingModeRTZ, SPIRV_C_ImageFootprintNV, - SPIRV_C_FragmentBarycentricKHR, SPIRV_C_ComputeDerivativeGroupQuadsNV, + SPIRV_C_FragmentBarycentricKHR, SPIRV_C_MeshShadingEXT, SPIRV_C_ComputeDerivativeGroupQuadsNV, SPIRV_C_GroupNonUniformPartitionedNV, SPIRV_C_VulkanMemoryModel, SPIRV_C_VulkanMemoryModelDeviceScope, SPIRV_C_ComputeDerivativeGroupLinearNV, SPIRV_C_BindlessTextureNV, SPIRV_C_SubgroupShuffleINTEL, @@ -1576,7 +1584,7 @@ def SPIRV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6> { } def SPIRV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7> { list availability = [ - Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_Tessellation]> + Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]> ]; } def SPIRV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8> { @@ -1586,12 +1594,12 @@ def SPIRV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8> { } def SPIRV_BI_Layer : I32EnumAttrCase<"Layer", 9> { list availability = [ - Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]> + Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_ShaderLayer, SPIRV_C_ShaderViewportIndexLayerEXT]> ]; } def SPIRV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10> { list availability = [ - Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]> + Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_MultiViewport, SPIRV_C_ShaderViewportIndex, SPIRV_C_ShaderViewportIndexLayerEXT]> ]; } def SPIRV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11> { @@ -1769,8 +1777,8 @@ def SPIRV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425> } def SPIRV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426> { list availability = [ - Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>, - Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV]> + Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_DrawParameters, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } def SPIRV_BI_PrimitiveShadingRateKHR : I32EnumAttrCase<"PrimitiveShadingRateKHR", 4432> { @@ -1946,6 +1954,30 @@ def SPIRV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountE Capability<[SPIRV_C_FragmentDensityEXT]> ]; } +def SPIRV_BI_PrimitivePointIndicesEXT : I32EnumAttrCase<"PrimitivePointIndicesEXT", 5294> { + list availability = [ + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} +def SPIRV_BI_PrimitiveLineIndicesEXT : I32EnumAttrCase<"PrimitiveLineIndicesEXT", 5295> { + list availability = [ + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} +def SPIRV_BI_PrimitiveTriangleIndicesEXT : I32EnumAttrCase<"PrimitiveTriangleIndicesEXT", 5296> { + list availability = [ + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} +def SPIRV_BI_CullPrimitiveEXT : I32EnumAttrCase<"CullPrimitiveEXT", 5299> { + list availability = [ + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} def SPIRV_BI_LaunchIdKHR : I32EnumAttrCase<"LaunchIdKHR", 5319> { list availability = [ Extension<[SPV_KHR_ray_tracing, SPV_NV_ray_tracing]>, @@ -2102,7 +2134,9 @@ def SPIRV_BuiltInAttr : SPIRV_BI_ClipDistancePerViewNV, SPIRV_BI_CullDistancePerViewNV, SPIRV_BI_LayerPerViewNV, SPIRV_BI_MeshViewCountNV, SPIRV_BI_MeshViewIndicesNV, SPIRV_BI_BaryCoordKHR, SPIRV_BI_BaryCoordNoPerspKHR, SPIRV_BI_FragSizeEXT, - SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR, + SPIRV_BI_FragInvocationCountEXT, SPIRV_BI_PrimitivePointIndicesEXT, + SPIRV_BI_PrimitiveLineIndicesEXT, SPIRV_BI_PrimitiveTriangleIndicesEXT, + SPIRV_BI_CullPrimitiveEXT, SPIRV_BI_LaunchIdKHR, SPIRV_BI_LaunchSizeKHR, SPIRV_BI_WorldRayOriginKHR, SPIRV_BI_WorldRayDirectionKHR, SPIRV_BI_ObjectRayOriginKHR, SPIRV_BI_ObjectRayDirectionKHR, SPIRV_BI_RayTminKHR, SPIRV_BI_RayTmaxKHR, SPIRV_BI_InstanceCustomIndexKHR, SPIRV_BI_ObjectToWorldKHR, @@ -2358,10 +2392,10 @@ def SPIRV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewp Capability<[SPIRV_C_ShaderStereoViewNV]> ]; } -def SPIRV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271> { +def SPIRV_D_PerPrimitiveEXT : I32EnumAttrCase<"PerPrimitiveEXT", 5271> { list availability = [ - Extension<[SPV_NV_mesh_shader]>, - Capability<[SPIRV_C_MeshShadingNV]> + Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } def SPIRV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272> { @@ -2660,7 +2694,7 @@ def SPIRV_DecorationAttr : SPIRV_D_AlignmentId, SPIRV_D_MaxByteOffsetId, SPIRV_D_NoSignedWrap, SPIRV_D_NoUnsignedWrap, SPIRV_D_ExplicitInterpAMD, SPIRV_D_OverrideCoverageNV, SPIRV_D_PassthroughNV, SPIRV_D_ViewportRelativeNV, - SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveNV, SPIRV_D_PerViewNV, + SPIRV_D_SecondaryViewportRelativeNV, SPIRV_D_PerPrimitiveEXT, SPIRV_D_PerViewNV, SPIRV_D_PerTaskNV, SPIRV_D_PerVertexKHR, SPIRV_D_NonUniform, SPIRV_D_RestrictPointer, SPIRV_D_AliasedPointer, SPIRV_D_BindlessSamplerNV, SPIRV_D_BindlessImageNV, SPIRV_D_BoundSamplerNV, SPIRV_D_BoundImageNV, SPIRV_D_SIMTCallINTEL, @@ -2843,12 +2877,12 @@ def SPIRV_EM_Isolines : I32EnumAttrCase<"Isolines", 25> } def SPIRV_EM_OutputVertices : I32EnumAttrCase<"OutputVertices", 26> { list availability = [ - Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_Tessellation]> + Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT, SPIRV_C_Tessellation]> ]; } def SPIRV_EM_OutputPoints : I32EnumAttrCase<"OutputPoints", 27> { list availability = [ - Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV]> + Capability<[SPIRV_C_Geometry, SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } def SPIRV_EM_OutputLineStrip : I32EnumAttrCase<"OutputLineStrip", 28> { @@ -3002,16 +3036,16 @@ def SPIRV_EM_StencilRefLessBackAMD : I32EnumAttrCase<"StencilRefLessB Capability<[SPIRV_C_StencilExportEXT]> ]; } -def SPIRV_EM_OutputLinesNV : I32EnumAttrCase<"OutputLinesNV", 5269> { +def SPIRV_EM_OutputLinesEXT : I32EnumAttrCase<"OutputLinesEXT", 5269> { list availability = [ - Extension<[SPV_NV_mesh_shader]>, - Capability<[SPIRV_C_MeshShadingNV]> + Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } -def SPIRV_EM_OutputPrimitivesNV : I32EnumAttrCase<"OutputPrimitivesNV", 5270> { +def SPIRV_EM_OutputPrimitivesEXT : I32EnumAttrCase<"OutputPrimitivesEXT", 5270> { list availability = [ - Extension<[SPV_NV_mesh_shader]>, - Capability<[SPIRV_C_MeshShadingNV]> + Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } def SPIRV_EM_DerivativeGroupQuadsNV : I32EnumAttrCase<"DerivativeGroupQuadsNV", 5289> { @@ -3026,10 +3060,10 @@ def SPIRV_EM_DerivativeGroupLinearNV : I32EnumAttrCase<"DerivativeGroup Capability<[SPIRV_C_ComputeDerivativeGroupLinearNV]> ]; } -def SPIRV_EM_OutputTrianglesNV : I32EnumAttrCase<"OutputTrianglesNV", 5298> { +def SPIRV_EM_OutputTrianglesEXT : I32EnumAttrCase<"OutputTrianglesEXT", 5298> { list availability = [ - Extension<[SPV_NV_mesh_shader]>, - Capability<[SPIRV_C_MeshShadingNV]> + Extension<[SPV_NV_mesh_shader, SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingNV, SPIRV_C_MeshShadingEXT]> ]; } def SPIRV_EM_PixelInterlockOrderedEXT : I32EnumAttrCase<"PixelInterlockOrderedEXT", 5366> { @@ -3154,9 +3188,9 @@ def SPIRV_ExecutionModeAttr : SPIRV_EM_StencilRefReplacingEXT, SPIRV_EM_StencilRefUnchangedFrontAMD, SPIRV_EM_StencilRefGreaterFrontAMD, SPIRV_EM_StencilRefLessFrontAMD, SPIRV_EM_StencilRefUnchangedBackAMD, SPIRV_EM_StencilRefGreaterBackAMD, - SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesNV, SPIRV_EM_OutputPrimitivesNV, - SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV, - SPIRV_EM_OutputTrianglesNV, SPIRV_EM_PixelInterlockOrderedEXT, + SPIRV_EM_StencilRefLessBackAMD, SPIRV_EM_OutputLinesEXT, + SPIRV_EM_OutputPrimitivesEXT, SPIRV_EM_DerivativeGroupQuadsNV, SPIRV_EM_DerivativeGroupLinearNV, + SPIRV_EM_OutputTrianglesEXT, SPIRV_EM_PixelInterlockOrderedEXT, SPIRV_EM_PixelInterlockUnorderedEXT, SPIRV_EM_SampleInterlockOrderedEXT, SPIRV_EM_SampleInterlockUnorderedEXT, SPIRV_EM_ShadingRateInterlockOrderedEXT, SPIRV_EM_ShadingRateInterlockUnorderedEXT, SPIRV_EM_SharedLocalMemorySizeINTEL, @@ -3243,13 +3277,24 @@ def SPIRV_EM_CallableKHR : I32EnumAttrCase<"CallableKHR", 5318> { Capability<[SPIRV_C_RayTracingKHR, SPIRV_C_RayTracingNV]> ]; } +def SPIRV_EM_TaskEXT : I32EnumAttrCase<"TaskEXT", 5364> { + list availability = [ + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} +def SPIRV_EM_MeshEXT : I32EnumAttrCase<"MeshEXT", 5365> { + list availability = [ + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} def SPIRV_ExecutionModelAttr : SPIRV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [ SPIRV_EM_Vertex, SPIRV_EM_TessellationControl, SPIRV_EM_TessellationEvaluation, SPIRV_EM_Geometry, SPIRV_EM_Fragment, SPIRV_EM_GLCompute, SPIRV_EM_Kernel, SPIRV_EM_TaskNV, SPIRV_EM_MeshNV, SPIRV_EM_RayGenerationKHR, SPIRV_EM_IntersectionKHR, - SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR + SPIRV_EM_AnyHitKHR, SPIRV_EM_ClosestHitKHR, SPIRV_EM_MissKHR, SPIRV_EM_CallableKHR, + SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT ]>; def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">; @@ -3982,6 +4027,13 @@ def SPIRV_SC_PhysicalStorageBuffer : I32EnumAttrCase<"PhysicalStorageBuffer", Capability<[SPIRV_C_PhysicalStorageBufferAddresses]> ]; } +def SPIRV_SC_TaskPayloadWorkgroupEXT : I32EnumAttrCase<"TaskPayloadWorkgroupEXT", 5402> { + list availability = [ + MinVersion, + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; +} def SPIRV_SC_CodeSectionINTEL : I32EnumAttrCase<"CodeSectionINTEL", 5605> { list availability = [ Extension<[SPV_INTEL_function_pointers]>, @@ -4009,7 +4061,8 @@ def SPIRV_StorageClassAttr : SPIRV_SC_StorageBuffer, SPIRV_SC_CallableDataKHR, SPIRV_SC_IncomingCallableDataKHR, SPIRV_SC_RayPayloadKHR, SPIRV_SC_HitAttributeKHR, SPIRV_SC_IncomingRayPayloadKHR, SPIRV_SC_ShaderRecordBufferKHR, SPIRV_SC_PhysicalStorageBuffer, - SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL + SPIRV_SC_TaskPayloadWorkgroupEXT, SPIRV_SC_CodeSectionINTEL, + SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL ]>; def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> { @@ -4524,6 +4577,8 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMat def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>; def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>; def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>; +def SPIRV_OC_OpEmitMeshTasksEXT : I32EnumAttrCase<"OpEmitMeshTasksEXT", 5294>; +def SPIRV_OC_OpSetMeshOutputsEXT : I32EnumAttrCase<"OpSetMeshOutputsEXT", 5295>; def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>; def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>; def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>; @@ -4622,7 +4677,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR, SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR, - SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpSubgroupBlockReadINTEL, + SPIRV_OC_OpCooperativeMatrixLengthKHR, SPIRV_OC_OpEmitMeshTasksEXT, + SPIRV_OC_OpSetMeshOutputsEXT, SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL, SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td new file mode 100755 index 0000000000000..a2e3d0509525f --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td @@ -0,0 +1,139 @@ +//===-- SPIRVMeshOps.td - MLIR SPIR-V Mesh Ops ------*- tablegen -*----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===------------------------------------------------------------------------------===// +// +// This file contains mesh ops for the SPIR-V dialect. It corresponds +// to the part of "3.52.25. Reserved Instructions" of the SPIR-V specification, and +// to the SPV_EXT_mesh_shader specification. +// +//===------------------------------------------------------------------------ -----===// + +#ifndef MLIR_DIALECT_SPIRV_MESH_OPS +#define MLIR_DIALECT_SPIRV_MESH_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" + +// ----- + +def SPIRV_EXTEmitMeshTasksOp : SPIRV_ExtVendorOp<"EmitMeshTasks", [Terminator]> { + let summary = [{ + Defines the grid size of subsequent mesh shader workgroups to generate upon + completion of the task shader workgroup. + }]; + + let description = [{ + Defines the grid size of subsequent mesh shader workgroups to generate upon + completion of the task shader workgroup. + + Group Count X Y Z must each be a 32-bit unsigned integer value. They + configure the number of local workgroups in each respective dimensions for the + launch of child mesh tasks. See Vulkan API specification for more detail. + + Payload is an optional pointer to the payload structure to pass to the + generated mesh shader invocations. Payload must be the result of an OpVariable + with a storage class of TaskPayloadWorkgroupEXT. + + The arguments are taken from the first invocation in each workgroup. + Behaviour is undefined if any invocation terminates without executing this + instruction, or if any invocation executes this instruction in non-uniform + control flow. + + This instruction also serves as an OpControlBarrier instruction, and also + performs and adheres to the description and semantics of an OpControlBarrier + instruction with the Execution and Memory operands set to Workgroup and the + Semantics operand set to a combination of WorkgroupMemory and AcquireRelease. + + Ceases all further processing: Only instructions executed before + OpEmitMeshTasksEXT have observable side effects. + + This instruction must be the last instruction in a block. + + This instruction is only valid in the TaskEXT Execution Model. + + + + #### Example: + + ```mlir + spirv.EmitMeshTasksEXT %x, %y, %z : i32, i32, i32 + spirv.EmitMeshTasksEXT %x, %x, %z, %payload : i32, i32, i32, !spirv.ptr + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; + + let arguments = (ins + SignlessOrUnsignedIntOfWidths<[32]>:$group_count_x, + SignlessOrUnsignedIntOfWidths<[32]>:$group_count_y, + SignlessOrUnsignedIntOfWidths<[32]>:$group_count_z, + Optional:$payload + ); + + let results = (outs); + + let assemblyFormat = [{ + operands attr-dict `:` type(operands) + }]; +} + +// ----- + +def SPIRV_EXTSetMeshOutputsOp : SPIRV_ExtVendorOp<"SetMeshOutputs", []> { + let summary = [{ + Sets the actual output size of the primitives and vertices that the mesh + shader workgroup will emit upon completion. + }]; + + let description = [{ + Vertex Count must be a 32-bit unsigned integer value. It defines the array size + of per-vertex outputs. + + Primitive Count must a 32-bit unsigned integer value. It defines the array size + of per-primitive outputs. + + The arguments are taken from the first invocation in each workgroup. Behavior + is undefined if any invocation executes this instruction more than once or + under non-uniform control flow. Behavior is undefined if there is any control + flow path to an output write that is not preceded by this instruction. + + This instruction is only valid in the MeshEXT Execution Model. + + + + #### Example: + + ```mlir + spirv.SetMeshOutputsEXT %vcount, %pcount : i32, i32 + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_EXT_mesh_shader]>, + Capability<[SPIRV_C_MeshShadingEXT]> + ]; + + let arguments = (ins + SignlessOrUnsignedIntOfWidths<[32]>:$vertex_count, + SignlessOrUnsignedIntOfWidths<[32]>:$primitive_count + ); + + let results = (outs); + let hasVerifier = 0; + + let assemblyFormat = [{ + operands attr-dict `:` type(operands) + }]; +} + +#endif // MLIR_DIALECT_SPIRV_MESH_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td index ff1ca89f93b5a..0fa1bb9d5bd01 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -38,6 +38,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td" diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt index 7d760e0dd8022..ae8ad5a491ff2 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect GroupOps.cpp IntegerDotProductOps.cpp MemoryOps.cpp + MeshOps.cpp SPIRVAttributes.cpp SPIRVCanonicalization.cpp SPIRVGLCanonicalization.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp new file mode 100644 index 0000000000000..a04f077606224 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/IR/MeshOps.cpp @@ -0,0 +1,34 @@ +//===- MeshOps.cpp - MLIR SPIR-V Mesh Ops --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the mesh operations in the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// spirv.EXT.EmitMeshTasks +//===----------------------------------------------------------------------===// + +LogicalResult spirv::EXTEmitMeshTasksOp::verify() { + if (Value payload = getPayload()) { + // The operand definition restricts type to be SPIRV_AnyPointer, so we can + // cast here safely. + auto payloadType = cast(payload.getType()); + if (payloadType.getStorageClass() != + spirv::StorageClass::TaskPayloadWorkgroupEXT) + return emitOpError("payload must be a variable with a storage class of " + "TaskPayloadWorkgroupEXT"); + } + return success(); +} diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index 31a90ad0329d8..64ba8e3fc249e 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -255,3 +255,26 @@ func.func @end_primitive() -> () { spirv.EndPrimitive return } + +//===----------------------------------------------------------------------===// +// Mesh ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: emit_mesh_tasks +func.func @emit_mesh_tasks(%0 : i32) -> () { + // CHECK: min version: v1.4 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_EXT_mesh_shader] ] + // CHECK: capabilities: [ [MeshShadingEXT] ] + spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32 +} + +// CHECK-LABEL: set_mesh_outputs +func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () { + // CHECK: min version: v1.4 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_EXT_mesh_shader] ] + // CHECK: capabilities: [ [MeshShadingEXT] ] + spirv.EXT.SetMeshOutputs %0, %1 : i32, i32 + spirv.Return +} diff --git a/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir b/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir new file mode 100644 index 0000000000000..436f7d1c9fb15 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/mesh-ops.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// spirv.EmitMeshTasksEXT +//===----------------------------------------------------------------------===// + +func.func @emit_mesh_tasks(%0 : i32) { + // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32 + spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32 +} + +func.func @emit_mesh_tasks_payload(%0 : i32, %1 : !spirv.ptr) { + // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32, !spirv.ptr + spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr +} + +// ----- + +func.func @emit_mesh_tasks_wrong_payload(%0 : i32, %1 : !spirv.ptr) { + // expected-error @+1 {{payload must be a variable with a storage class of TaskPayloadWorkgroupEXT}} + spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.SetMeshOutputsEXT +//===----------------------------------------------------------------------===// + +func.func @set_mesh_outputs(%0 : i32, %1 : i32) { + // CHECK: spirv.EXT.SetMeshOutputs {{%.*}}, {{%.*}} : i32, i32 + spirv.EXT.SetMeshOutputs %0, %1 : i32, i32 + spirv.Return +} diff --git a/mlir/test/Target/SPIRV/mesh-ops.mlir b/mlir/test/Target/SPIRV/mesh-ops.mlir new file mode 100644 index 0000000000000..3b937072de04e --- /dev/null +++ b/mlir/test/Target/SPIRV/mesh-ops.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK-LABEL: @emit_mesh_tasks + spirv.func @emit_mesh_tasks() "None" { + %0 = spirv.Constant 1 : i32 + // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32 + spirv.EXT.EmitMeshTasks %0, %0, %0 : i32, i32, i32 + } + // CHECK-LABEL: @set_mesh_outputs + spirv.func @set_mesh_outputs(%0 : i32, %1 : i32) "None" { + // CHECK: spirv.EXT.SetMeshOutputs {{%.*}}, {{%.*}} : i32, i32 + spirv.EXT.SetMeshOutputs %0, %1 : i32, i32 + spirv.Return + } + // CHECK: spirv.EntryPoint "TaskEXT" {{@.*}} + spirv.EntryPoint "TaskEXT" @emit_mesh_tasks +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.GlobalVariable @payload : !spirv.ptr + // CHECK-LABEL: @emit_mesh_tasks_payload + spirv.func @emit_mesh_tasks_payload() "None" { + %0 = spirv.Constant 1 : i32 + %1 = spirv.mlir.addressof @payload : !spirv.ptr + // CHECK: spirv.EXT.EmitMeshTasks {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : i32, i32, i32, !spirv.ptr + spirv.EXT.EmitMeshTasks %0, %0, %0, %1 : i32, i32, i32, !spirv.ptr + } + // CHECK: spirv.EntryPoint "TaskEXT" {{@.*}}, {{@.*}} + spirv.EntryPoint "TaskEXT" @emit_mesh_tasks_payload, @payload +}