diff --git a/external/slang-rhi b/external/slang-rhi index ff87b3ac8d7..9eb7734ab0e 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit ff87b3ac8d76a5767260cf3ae9966a93faa09fe5 +Subproject commit 9eb7734ab0ebd4ead90b9ec0782dbb83521da164 diff --git a/include/slang.h b/include/slang.h index 906ac5bf10a..8d11b2f6ef7 100644 --- a/include/slang.h +++ b/include/slang.h @@ -857,6 +857,23 @@ typedef uint32_t SlangSizeT; SLANG_STAGE_PIXEL = SLANG_STAGE_FRAGMENT, }; + typedef SlangUInt32 SlangCooperativeMatrixUseIntegral; + enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral + { + SLANG_COOPERATIVE_MATRIX_USE_A, + SLANG_COOPERATIVE_MATRIX_USE_B, + SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR, + }; + + typedef SlangUInt32 SlangCooperativeVectorMatrixLayoutIntegral; + enum SlangCooperativeVectorMatrixLayout : SlangCooperativeVectorMatrixLayoutIntegral + { + SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR, + SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR, + SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL, + SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL, + }; + typedef SlangUInt32 SlangDebugInfoLevelIntegral; enum SlangDebugInfoLevel : SlangDebugInfoLevelIntegral { @@ -1964,7 +1981,10 @@ public: \ SLANG_SCALAR_TYPE_INT16, SLANG_SCALAR_TYPE_UINT16, SLANG_SCALAR_TYPE_INTPTR, - SLANG_SCALAR_TYPE_UINTPTR + SLANG_SCALAR_TYPE_UINTPTR, + SLANG_SCALAR_TYPE_BFLOAT16, + SLANG_SCALAR_TYPE_FLOAT_E4M3, + SLANG_SCALAR_TYPE_FLOAT_E5M2, }; // abstract decl reflection @@ -2349,6 +2369,11 @@ struct TypeReflection UInt8 = SLANG_SCALAR_TYPE_UINT8, Int16 = SLANG_SCALAR_TYPE_INT16, UInt16 = SLANG_SCALAR_TYPE_UINT16, + IntPtr = SLANG_SCALAR_TYPE_INTPTR, + UIntPtr = SLANG_SCALAR_TYPE_UINTPTR, + BFloat16 = SLANG_SCALAR_TYPE_BFLOAT16, + FloatE4M3 = SLANG_SCALAR_TYPE_FLOAT_E4M3, + FloatE5M2 = SLANG_SCALAR_TYPE_FLOAT_E5M2, }; Kind getKind() { return (Kind)spReflectionType_GetKind((SlangReflectionType*)this); } diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index d1ca4601d0b..827920c6c85 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -6461,63 +6461,6 @@ _slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int cluster #if (OPTIX_VERSION >= 90000) -// Constexpr function to map Slang component type enum to OptiX cooperative vector element type -__host__ __device__ constexpr OptixCoopVecElemType slangToOptixComponentType(unsigned slangEnum) -{ - switch (slangEnum) - { - case 0: - return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; // FloatE4M3 - case 1: - return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; // FloatE5M2 - case 2: - return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; // Float16 - case 3: - return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Float32 - case 5: - return OPTIX_COOP_VEC_ELEM_TYPE_INT8; // SignedInt8 - case 7: - return OPTIX_COOP_VEC_ELEM_TYPE_INT32; // SignedInt32 - case 10: - return OPTIX_COOP_VEC_ELEM_TYPE_UINT8; // UnsignedInt8 - case 12: - return OPTIX_COOP_VEC_ELEM_TYPE_UINT32; // UnsignedInt32 - default: - return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Default - } -} - -// Constexpr function to map Slang matrix layout enum to OptiX cooperative vector matrix layout -__host__ __device__ constexpr OptixCoopVecMatrixLayout slangToOptixMatrixLayout(unsigned slangEnum) -{ - switch (slangEnum) - { - case 0: - return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // RowMajor - case 1: - return OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; // ColumnMajor - case 2: - return OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; // InferencingOptimal - case 3: - return OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; // TrainingOptimal - default: - return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default - } -} - -// Wrapper structs to maintain compatibility with existing template-based interface -template -struct SlangToOptixComponentType -{ - static constexpr OptixCoopVecElemType value = slangToOptixComponentType(SlangEnum); -}; - -template -struct SlangToOptixMatrixLayout -{ - static constexpr OptixCoopVecMatrixLayout value = slangToOptixMatrixLayout(SlangEnum); -}; - // Template trait to extract vector size from OptixCoopVec // Conditional compilation for NVRTC compatibility template @@ -6537,9 +6480,9 @@ struct OptixCoopVecTraits> template< typename VecTOut, typename VecTIn, - unsigned inputInterpretation, - unsigned matrixInterpretation, - unsigned matrixLayout> + OptixCoopVecElemType inputInterpretation, + OptixCoopVecElemType matrixInterpretation, + OptixCoopVecMatrixLayout matrixLayout> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, @@ -6553,26 +6496,22 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( return optixCoopVecMatMul< VecTOut, VecTIn, - SlangToOptixComponentType::value, - SlangToOptixMatrixLayout::value, + inputInterpretation, + matrixLayout, false, N, K, - SlangToOptixComponentType::value>( - inputVector, - matrix, - matrixOffset, - matrixStride); + matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride); } // OptiX cooperative vector matrix multiplication wrapper (WITH bias - 6 runtime params) template< typename VecTOut, typename VecTIn, - unsigned inputInterpretation, - unsigned matrixInterpretation, - unsigned matrixLayout, - unsigned biasInterpretation> + OptixCoopVecElemType inputInterpretation, + OptixCoopVecElemType matrixInterpretation, + OptixCoopVecMatrixLayout matrixLayout, + OptixCoopVecElemType biasInterpretation> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, @@ -6588,19 +6527,13 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( return optixCoopVecMatMul< VecTOut, VecTIn, - SlangToOptixComponentType::value, - SlangToOptixMatrixLayout::value, + inputInterpretation, + matrixLayout, false, N, K, - SlangToOptixComponentType::value, - SlangToOptixComponentType::value>( - inputVector, - matrix, - matrixOffset, - bias, - biasOffset, - matrixStride); + matrixInterpretation, + biasInterpretation>(inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride); } // OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params - @@ -6608,9 +6541,9 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( template< typename VecTOut, typename VecTIn, - unsigned inputInterpretation, - unsigned matrixInterpretation, - unsigned matrixLayout> + OptixCoopVecElemType inputInterpretation, + OptixCoopVecElemType matrixInterpretation, + OptixCoopVecMatrixLayout matrixLayout> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, @@ -6624,16 +6557,12 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( return optixCoopVecMatMul< VecTOut, VecTIn, - SlangToOptixComponentType::value, - SlangToOptixMatrixLayout::value, + inputInterpretation, + matrixLayout, false, N, K, - SlangToOptixComponentType::value>( - inputVector, - matrix, - matrixOffset, - matrixStride); + matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride); } #endif // (OPTIX_VERSION >= 90000) diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 006a4903bbf..b3f048822bb 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -25,7 +25,12 @@ namespace x(Int64, int64_t) \ x(UInt64, uint64_t) \ x(Float32, float) \ - x(Float64, double) + x(Float64, double) \ + x(IntPtr, intptr_t) \ + x(UIntPtr, uintptr_t) \ + x(BFloat16, bfloat16) \ + x(FloatE4M3, float_e4m3) \ + x(FloatE5M2, float_e5m2) // clang-format on struct ScalarTypeInfo diff --git a/source/slang-wasm/slang-wasm-bindings.cpp b/source/slang-wasm/slang-wasm-bindings.cpp index 9ef8efc187f..82a69ea7632 100644 --- a/source/slang-wasm/slang-wasm-bindings.cpp +++ b/source/slang-wasm/slang-wasm-bindings.cpp @@ -157,7 +157,12 @@ EMSCRIPTEN_BINDINGS(slang) .value("Int8", slang::TypeReflection::ScalarType::Int8) .value("UInt8", slang::TypeReflection::ScalarType::UInt8) .value("Int16", slang::TypeReflection::ScalarType::Int16) - .value("UInt16", slang::TypeReflection::ScalarType::UInt16); + .value("UInt16", slang::TypeReflection::ScalarType::UInt16) + .value("IntPtr", slang::TypeReflection::ScalarType::IntPtr) + .value("UIntPtr", slang::TypeReflection::ScalarType::UIntPtr) + .value("BFloat16", slang::TypeReflection::ScalarType::BFloat16) + .value("FloatE4M3", slang::TypeReflection::ScalarType::FloatE4M3) + .value("FloatE5M2", slang::TypeReflection::ScalarType::FloatE5M2); class_("TypeReflection") .function("getScalarType", &slang::wgsl::TypeReflection::getScalarType) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 80a1182e0bd..cf457fda875 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -27125,9 +27125,9 @@ namespace linalg enum CoopMatMatrixUse { - MatrixA = 0, - MatrixB = 1, - MatrixAccumulator = 2, + MatrixA = $(SLANG_COOPERATIVE_MATRIX_USE_A), + MatrixB = $(SLANG_COOPERATIVE_MATRIX_USE_B), + MatrixAccumulator = $(SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR), }; enum CoopMatMatrixLayout @@ -29099,6 +29099,22 @@ CoopMat __float_to_int_cast< // Cooperative Matrix Multiply-Accumulate // +__intrinsic_op($(kIROp_CoopMatMulAdd)) +internal CoopMat __coopMatMulAdd< + T : ICoopElement, + U : ICoopElement, + V : ICoopElement, + W : ICoopElement, + let S : MemoryScope, + let M : int, + let K : int, + let N : int +>( + CoopMat matA, + CoopMat matB, + CoopMat matC, + bool saturatingAccumulation); + /// Performs cooperative matrix multiply-accumulate operation: D = A * B + C /// This is the fundamental operation for hardware-accelerated matrix multiplication. /// All threads in the warp/subgroup cooperate to compute the result. @@ -29131,61 +29147,7 @@ CoopMat coopMatMulAdd< CoopMat matB, CoopMat matC) { - __target_switch - { - case spirv: - // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands - int operands = 0; // NoneKHR - if (__isSignedInt()) - { - operands |= 0x01; // MatrixASignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x02; // MatrixBSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x04; // MatrixCSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x08; // MatrixResultSignedComponentsKHR - } - if (saturatingAccumulation) - { - operands |= 0x10; // SaturatingAccumulationKHR - } - - return spirv_asm - { - result:$$CoopMat = OpCooperativeMatrixMulAddKHR $matA $matB $matC !operands; - }; - case cuda: - if (saturatingAccumulation) - __intrinsic_asm R"(Slang_CUDA_WMMA::coopMatMulAdd< - $T0::ElementType, - $T1::ElementType, - $T2::ElementType, - $TR::ElementType, - $T0::m_M, - $T0::m_N, - $T0::m_K, - true - >($0, $1, $2))"; - else - __intrinsic_asm R"(Slang_CUDA_WMMA::coopMatMulAdd< - $T0::ElementType, - $T1::ElementType, - $T2::ElementType, - $TR::ElementType, - $T0::m_M, - $T0::m_N, - $T0::m_K, - false - >($0, $1, $2))"; - - } + return __coopMatMulAdd(matA, matB, matC, saturatingAccumulation); } extension< @@ -30228,63 +30190,6 @@ static const struct { }; for(auto buffer : kByteAddressBufferCases) { $} - [mutating] - [ForceInline] - [require(hlsl, byteaddressbuffer_rw)] - void __mutMatMul( - CoopVec input, uint inputInterpretationHLSL, - $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL, - uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride) - { - __target_switch - { - case hlsl: - if (__isFloat() || __isSignedInt()) - { - if (__isFloat() || __isSignedInt()) - __intrinsic_asm "__builtin_MatVecMul($0, false, $1, false, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - else - __intrinsic_asm "__builtin_MatVecMul($0, false, $1, true, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - } - else - { - if (__isFloat() || __isSignedInt()) - __intrinsic_asm "__builtin_MatVecMul($0, true, $1, false, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - else - __intrinsic_asm "__builtin_MatVecMul($0, true, $1, true, $2, $3, $4, $5, $6, $7, $8, $9, $10)"; - } - } - } - - [mutating] - [ForceInline] - [require(hlsl, byteaddressbuffer_rw)] - void __mutMatMulAdd( - CoopVec input, uint inputInterpretationHLSL, - $(buffer.type) matrix, uint matrixOffset, uint matrixInterpretationHLSL, - $(buffer.type) bias, uint biasOffset, uint biasInterpretationHLSL, - uint m, uint k, uint memoryLayoutHLSL, bool transpose, uint matrixStride) - { - __target_switch - { - case hlsl: - if (__isFloat() || __isSignedInt()) - { - if (__isFloat() || __isSignedInt()) - __intrinsic_asm "__builtin_MatVecMulAdd($0, false, $1, false, $2, $3, $4, $5, $9, $10, $11, $12, $13, $6, $7, $8)"; - else - __intrinsic_asm "__builtin_MatVecMulAdd($0, false, $1, true, $2, $3, $4, $5, $9, $10, $11, $12, $13, $6, $7, $8)"; - } - else - { - if (__isFloat() || __isSignedInt()) - __intrinsic_asm "__builtin_MatVecMulAdd($0, true, $1, false, $2, $3, $4, $5, $9, $10, $11, $12, $13, $6, $7, $8)"; - else - __intrinsic_asm "__builtin_MatVecMulAdd($0, true, $1, true, $2, $3, $4, $5, $9, $10, $11, $12, $13, $6, $7, $8)"; - } - } - } - /// Multiply the given input Cooperative vector with the given matrix and accumulate the result into this vector. /// @param input The input Cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). @@ -30316,41 +30221,20 @@ $} static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK , "for non-packed inputInterpretation values k must be equal to the input vector length"); static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + || k <= __componentPackingFactor(inputInterpretation)*PackedK , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - __target_switch - { - case hlsl: - let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); - let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); - let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); - This temp = this; - temp.__mutMatMul( - input, - inputInterpretationHLSL, - matrix, - matrixOffset, - matrixInterpretationHLSL, - N, - k, - memoryLayoutHLSL, - transpose, - matrixStride - ); - this.__mutAdd(temp); - default: this = this + coopVecMatMulPacked( - input, - inputInterpretation, - k, - matrix, - matrixOffset, - matrixInterpretation, - memoryLayout, - transpose, - matrixStride - ); - } + this = this + coopVecMatMulPacked( + input, + inputInterpretation, + k, + matrix, + matrixOffset, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride + ); } /// Accumulate the result from a matrix multiplication between an input Cooperative vector and a matrix. @@ -30413,34 +30297,10 @@ $} static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK , "for non-packed inputInterpretation values k must be equal to the input vector length"); static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + || k <= __componentPackingFactor(inputInterpretation)*PackedK , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - __target_switch - { - case hlsl: - let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); - let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); - let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation); - let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); - This temp = this; - temp.__mutMatMulAdd( - input, - inputInterpretationHLSL, - matrix, - matrixOffset, - matrixInterpretationHLSL, - bias, - biasOffset, - biasInterpretationHLSL, - N, - k, - memoryLayoutHLSL, - transpose, - matrixStride - ); - this.__mutAdd(temp); - default: this = this + coopVecMatMulAddPacked( + this = this + coopVecMatMulAddPacked( input, inputInterpretation, k, @@ -30453,8 +30313,7 @@ $} memoryLayout, transpose, matrixStride - ); - } + ); } /// Performs matrix multiplication and accumulation with bias: this += input * matrix + bias @@ -30507,25 +30366,6 @@ $} ); } - [ForceInline] - [require(hlsl, byteaddressbuffer_rw)] - void __OuterProductAccumulate( - CoopVec b, - $(buffer.type) matrix, - int32_t matrixOffset, - uint matrixStride, - uint memoryLayout, - uint matrixInterpretation, - ) - { - __target_switch - { - case hlsl: - __intrinsic_asm "__builtin_OuterProductAccumulate($0, $1, $2, $3, $6, $5, $4)"; - } - } - - ${ } $} @@ -31060,10 +30900,10 @@ CoopVec coopVecLoadGroupshared +CoopVec __coopVecMatMul( + CoopVec input, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, + Ptr matrixPtr, + int32_t matrixOffset, + constexpr int matrixInterpretation, + constexpr int32_t k, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride); + +__intrinsic_op($(kIROp_CoopVecMatMulAdd)) +internal __generic +CoopVec __coopVecMatMul( + CoopVec input, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr int matrixInterpretation, + constexpr int32_t k, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride); + +__intrinsic_op($(kIROp_CoopVecMatMulAdd)) +internal __generic +CoopVec __coopVecMatMulAdd( + CoopVec input, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, + Ptr matrixPtr, + int32_t matrixOffset, + constexpr int matrixInterpretation, + constexpr int32_t k, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride, + Ptr biasPtr, + int32_t biasOffset, + constexpr int biasInterpretation); + +__intrinsic_op($(kIROp_CoopVecMatMulAdd)) +internal __generic +CoopVec __coopVecMatMulAdd( + CoopVec input, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr int matrixInterpretation, + constexpr int32_t k, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride, + $(buffer.type) bias, + int32_t biasOffset, + constexpr int biasInterpretation); + +__intrinsic_op($(kIROp_CoopVecOuterProductAccumulate)) +internal __generic +void __coopVecOuterProductAccumulate( + Ptr matrixPtr, + int32_t matrixOffset, + CoopVec a, + CoopVec b, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr int matrixInterpretation, + constexpr uint matrixStride); + +__intrinsic_op($(kIROp_CoopVecOuterProductAccumulate)) +internal __generic +void __coopVecOuterProductAccumulate( + $(buffer.type) matrix, + int32_t matrixOffset, + CoopVec a, + CoopVec b, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr int matrixInterpretation, + constexpr uint matrixStride); + +__intrinsic_op($(kIROp_CoopVecReduceSumAccumulate)) +internal __generic +void __coopVecReduceSumAccumulate( + Ptr bufferPtr, + int32_t offset, + CoopVec v); + +__intrinsic_op($(kIROp_CoopVecReduceSumAccumulate)) +internal __generic +void __coopVecReduceSumAccumulate( + $(buffer.type) buffer, + int32_t offset, + CoopVec v); + [ForceInline] [require(cooperative_vector)] [require(optix_coopvec)] __generic -CoopVec __coopVecMatMulPacked_impl( +CoopVec coopVecMatMulPacked( CoopVec input, constexpr CoopVecComponentType inputInterpretation, constexpr int k, @@ -31298,54 +31149,42 @@ CoopVec __coopVecMatMulPacked_impl( constexpr uint matrixStride ) { + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __componentPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + static_assert(__componentPackingFactor(matrixInterpretation) == 1, "can't use packed component type in matrixInterpretation"); + __target_switch { case spirv: - let m : int32_t = M; - let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); - let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); - let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); let matrixBuf = __getEquivalentStructuredBuffer(matrix); - let zero = 0; - - int operands = 0; - if (__isSignedInt()) - { - operands |= 0x02; - } - if (__isSignedInt()) - { - operands |= 0x08; - } - return spirv_asm - { - %runtimeArrayType = OpTypeRuntimeArray $$uint32_t; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %matrixPointer:%storagePointerType = OpAccessChain $matrixBuf $zero; - result:$$CoopVec = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv %matrixPointer $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; + return __coopVecMatMul( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrixBuf, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride); case hlsl: - var ret = CoopVec(0); - let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); - let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); - let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); - ret.__mutMatMul( - input, - inputInterpretationHLSL, - matrix, - matrixOffset, - matrixInterpretationHLSL, - M, - k, - memoryLayoutHLSL, - transpose, - matrixStride - ); - return ret; - case optix_coopvec: - __intrinsic_asm "slangOptixCoopVecMatMul, OptixCoopVec<$[2], $[3]>, $1, $5, $6>($0, (CUdeviceptr)(&($3)), $4, $7, $8)", T, M, U, PackedK; + return __coopVecMatMul( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrix, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride); default: var result = CoopVec(0); @@ -31443,42 +31282,6 @@ CoopVec __coopVecMatMulPacked_impl( /// values to use from the packed input. // TODO: Can we ForceInline for just hlsl? the other platforms don't really // need it -[ForceInline] -[require(cooperative_vector)] -[require(optix_coopvec)] -__generic -CoopVec coopVecMatMulPacked( - CoopVec input, - constexpr CoopVecComponentType inputInterpretation, - constexpr int k, - $(buffer.type) matrix, - int32_t matrixOffset, - constexpr CoopVecComponentType matrixInterpretation, - constexpr CoopVecMatrixLayout memoryLayout, - constexpr bool transpose, - constexpr uint matrixStride -) -{ - // Static assertions to validate parameters at compile time - // These can be evaluated because this wrapper function can be inlined (no GenericAsm) - static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK - , "for non-packed inputInterpretation values k must be equal to the input vector length"); - static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK - , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - - // Delegate to implementation function (which contains GenericAsm for OptiX) - return __coopVecMatMulPacked_impl( - input, - inputInterpretation, - k, - matrix, - matrixOffset, - matrixInterpretation, - memoryLayout, - transpose, - matrixStride); -} /// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and /// returns a M-element vector. @@ -31528,11 +31331,10 @@ CoopVec coopVecMatMul( matrixStride); } -// Internal implementation without static_assert; github issue 8620 [ForceInline] [require(cooperative_vector)] [require(optix_coopvec)] -CoopVec __coopVecMatMulAddPacked_impl( +CoopVec coopVecMatMulAddPacked( CoopVec input, constexpr CoopVecComponentType inputInterpretation, constexpr int k, @@ -31547,61 +31349,50 @@ CoopVec __coopVecMatMulAddPacked_impl(matrix); let biasBuf = __getEquivalentStructuredBuffer(bias); - let zero = 0; - - int operands = 0; - if (__isSignedInt()) - { - operands |= 0x02; - } - if (__isSignedInt()) - { - operands |= 0x08; - } - return spirv_asm - { - %runtimeArrayType = OpTypeRuntimeArray $$uint32_t; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %matrixPointer:%storagePointerType = OpAccessChain $matrixBuf $zero; - %biasPointer:%storagePointerType = OpAccessChain $biasBuf $zero; - result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv %matrixPointer $matrixOffset $matrixInterpretationSpirv %biasPointer $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; + return __coopVecMatMulAdd( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrixBuf, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride, + biasBuf, + biasOffset, + __getCoopVecComponentScalarType(biasInterpretation)); case hlsl: - var ret = CoopVec(0); - let inputInterpretationHLSL = __getHLSLCoopVecComponentType(inputInterpretation); - let matrixInterpretationHLSL = __getHLSLCoopVecComponentType(matrixInterpretation); - let biasInterpretationHLSL = __getHLSLCoopVecComponentType(biasInterpretation); - let memoryLayoutHLSL = __getHLSLCoopVecMatrixLayout(memoryLayout); - ret.__mutMatMulAdd( - input, - inputInterpretationHLSL, - matrix, - matrixOffset, - matrixInterpretationHLSL, - bias, - biasOffset, - biasInterpretationHLSL, - M, - k, - memoryLayoutHLSL, - transpose, - matrixStride - ); - return ret; - case optix_coopvec: - __intrinsic_asm "slangOptixCoopVecMatMul, OptixCoopVec<$[2], $[3]>, $1, $5, $9, $8>($0, (CUdeviceptr)(&($3)), $4, (CUdeviceptr)(&($6)), $7, $11)", T, M, U, PackedK; + return __coopVecMatMulAdd( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrix, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride, + bias, + biasOffset, + __getCoopVecComponentScalarType(biasInterpretation)); default: var result = coopVecMatMulPacked( @@ -31693,31 +31484,6 @@ CoopVec __coopVecMatMulAddPacked_impl coopVecMatMulAddPacked( - CoopVec input, - constexpr CoopVecComponentType inputInterpretation, - constexpr int k, - $(buffer.type) matrix, - int32_t matrixOffset, - constexpr CoopVecComponentType matrixInterpretation, - $(buffer.type) bias, - int32_t biasOffset, - constexpr CoopVecComponentType biasInterpretation, - constexpr CoopVecMatrixLayout memoryLayout, - constexpr bool transpose, - constexpr uint matrixStride -) -{ - static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK - , "for non-packed inputInterpretation values k must be equal to the input vector length"); - static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK - , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - return __coopVecMatMulAddPacked_impl(input, inputInterpretation, k, matrix, matrixOffset, matrixInterpretation, bias, biasOffset, biasInterpretation, memoryLayout, transpose, matrixStride); -} /// Multiply a matrix with a cooperative vector and add a bias vector. /// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and @@ -31822,27 +31588,30 @@ void coopVecOuterProductAccumulate(matrix); - let zero = 0; - spirv_asm - { - OpCapability CooperativeVectorTrainingNV; - %runtimeArrayType = OpTypeRuntimeArray $$uint32_t; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %matrixPointer:%storagePointerType = OpAccessChain $matrixBuf $zero; - OpCooperativeVectorOuterProductAccumulateNV %matrixPointer $matrixOffset $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; - }; + __coopVecOuterProductAccumulate( + matrixBuf, + matrixOffset, + a, + b, + memoryLayout, + __getCoopVecComponentScalarType(matrixInterpretation), + matrixStride); + case hlsl: case optix_coopvec: - __intrinsic_asm "optixCoopVecOuterProductAccumulate($0, $1, (CUdeviceptr)(&$2), $3, $4)"; + __coopVecOuterProductAccumulate( + matrix, + matrixOffset, + a, + b, + memoryLayout, + __getCoopVecComponentScalarType(matrixInterpretation), + matrixStride); default: for (int i = 0; i < M; ++i) { @@ -31909,6 +31678,7 @@ void coopVecOuterProductAccumulate( @@ -31919,21 +31689,12 @@ void coopVecReduceSumAccumulate( { __target_switch { - case hlsl: - __intrinsic_asm "__builtin_VectorAccumulate($0, $1, $2)"; case spirv: let bufBuf = __getEquivalentStructuredBuffer(buffer); - let zero = 0; - spirv_asm - { - OpCapability CooperativeVectorTrainingNV; - %runtimeArrayType = OpTypeRuntimeArray $$uint32_t; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %bufferPointer:%storagePointerType = OpAccessChain $bufBuf $zero; - OpCooperativeVectorReduceSumAccumulateNV %bufferPointer $offset $v; - }; + __coopVecReduceSumAccumulate(bufBuf, offset, v); + case hlsl: case optix_coopvec: - __intrinsic_asm "optixCoopVecReduceSumAccumulate($0, (CUdeviceptr)(&$1), $2)"; + __coopVecReduceSumAccumulate(buffer, offset, v); default: for (int i = 0; i < N; ++i) { @@ -31963,53 +31724,54 @@ static const struct { for(auto buffer : kStructuredBufferCases_) { $} -// Internal implementation without static_assert; github issue 8620 -[ForceInline] -[require(spirv, cooperative_vector)] -[require(optix_coopvec)] -__generic -CoopVec __coopVecMatMulPacked_impl( +__intrinsic_op($(kIROp_CoopVecMatMulAdd)) +internal __generic +CoopVec __coopVecMatMul( CoopVec input, - constexpr CoopVecComponentType inputInterpretation, - constexpr int k, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, $(buffer.type) matrix, int32_t matrixOffset, - constexpr CoopVecComponentType matrixInterpretation, + constexpr int matrixInterpretation, + constexpr int32_t k, constexpr CoopVecMatrixLayout memoryLayout, constexpr bool transpose, - constexpr uint matrixStride -) -{ - __target_switch - { - case spirv: - let m : int32_t = M; - let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); - let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); - let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); - let zero = 0; + constexpr uint matrixStride); - // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands - int operands = 0; // NoneKHR - if (__isSignedInt()) - { - operands |= 0x02; // MatrixBSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x08; // MatrixResultSignedComponentsKHR - } - return spirv_asm - { - %runtimeArrayType = OpTypeRuntimeArray $$IgnoredBufferElementType; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %matrixPointer:%storagePointerType = OpAccessChain $matrix $zero; - result:$$CoopVec = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv %matrixPointer $matrixOffset $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; - case optix_coopvec: - __intrinsic_asm "slangOptixCoopVecMatMul<$T0, $T0, $1, $5, $6>($0, (CUdeviceptr)(&($3)), $4, $8)"; - } -} +__intrinsic_op($(kIROp_CoopVecMatMulAdd)) +internal __generic +CoopVec __coopVecMatMulAdd( + CoopVec input, + constexpr int inputInterpretation, + constexpr int inputInterpretationPackingFactor, + $(buffer.type) matrix, + int32_t matrixOffset, + constexpr int matrixInterpretation, + constexpr int32_t k, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride, + $(buffer.biasType) bias, + int32_t biasOffset, + constexpr int biasInterpretation); + +__intrinsic_op($(kIROp_CoopVecOuterProductAccumulate)) +internal __generic +void __coopVecOuterProductAccumulate( + $(buffer.type) matrix, + int32_t matrixOffset, + CoopVec a, + CoopVec b, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr int matrixInterpretation, + constexpr uint matrixStride); + +__intrinsic_op($(kIROp_CoopVecReduceSumAccumulate)) +internal __generic +void __coopVecReduceSumAccumulate( + $(buffer.type) buffer, + int32_t offset, + CoopVec v); [ForceInline] [require(spirv, cooperative_vector)] @@ -32030,9 +31792,21 @@ CoopVec coopVecMatMulPacked( static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK , "for non-packed inputInterpretation values k must be equal to the input vector length"); static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + || k <= __componentPackingFactor(inputInterpretation)*PackedK , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - return __coopVecMatMulPacked_impl(input, inputInterpretation, k, matrix, matrixOffset, matrixInterpretation, memoryLayout, transpose, matrixStride); + static_assert(__componentPackingFactor(matrixInterpretation) == 1, "can't use packed component type in matrixInterpretation"); + + return __coopVecMatMul( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrix, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride); } // specialized coopVecMatMul for non-packed inputs @@ -32065,61 +31839,6 @@ CoopVec coopVecMatMul( matrixStride); } -// Internal implementation without static_assert; github issue 8620 -[ForceInline] -[require(spirv, cooperative_vector)] -[require(optix_coopvec)] -CoopVec __coopVecMatMulAddPacked_impl( - CoopVec input, - constexpr CoopVecComponentType inputInterpretation, - constexpr int k, - $(buffer.type) matrix, - int32_t matrixOffset, - constexpr CoopVecComponentType matrixInterpretation, - $(buffer.biasType) bias, - int32_t biasOffset, - constexpr CoopVecComponentType biasInterpretation, - constexpr CoopVecMatrixLayout memoryLayout, - constexpr bool transpose, - constexpr uint matrixStride -) -{ - __target_switch - { - case spirv: - let m : int32_t = M; - let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); - let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); - let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); - let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); - let zero = 0; - - // https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc#3x-cooperative-matrix-operands - int operands = 0; // NoneKHR - if (__isSignedInt()) - { - operands |= 0x02; // MatrixBSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x08; // MatrixResultSignedComponentsKHR - } - - return spirv_asm - { - %runtimeArrayType = OpTypeRuntimeArray $$IgnoredBufferElementType; - %runtimeBiasArrayType = OpTypeRuntimeArray $$IgnoredBiasBufferElementType; - %storagePointerType = OpTypePointer StorageBuffer %runtimeArrayType; - %storageBiasPointerType = OpTypePointer StorageBuffer %runtimeBiasArrayType; - %matrixPointer:%storagePointerType = OpAccessChain $matrix $zero; - %biasPointer:%storageBiasPointerType = OpAccessChain $bias $zero; - result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv %matrixPointer $matrixOffset $matrixInterpretationSpirv %biasPointer $biasOffset $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; - case optix_coopvec: - __intrinsic_asm "slangOptixCoopVecMatMul, OptixCoopVec<$[2], $[3]>, $1, $5, $9, $8>($0, (CUdeviceptr)(&($3)), $4, (CUdeviceptr)(&($6)), $7, $11)", T, M, U, PackedK; - } -} - [ForceInline] [require(spirv, cooperative_vector)] [require(optix_coopvec)] @@ -32141,11 +31860,27 @@ CoopVec coopVecMatMulAddPacked(input, inputInterpretation, k, matrix, matrixOffset, matrixInterpretation, bias, biasOffset, biasInterpretation, memoryLayout, transpose, matrixStride); + static_assert(__componentPackingFactor(matrixInterpretation) == 1, "can't use packed component type in matrixInterpretation"); + + return __coopVecMatMulAdd( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + matrix, + matrixOffset, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride, + bias, + biasOffset, + __getCoopVecComponentScalarType(biasInterpretation)); } + [require(spirv, cooperative_vector)] [require(optix_coopvec)] CoopVec coopVecMatMulAdd( @@ -32199,21 +31934,16 @@ void coopVecOuterProductAccumulate( + matrix, + matrixOffset, + a, + b, + memoryLayout, + __getCoopVecComponentScalarType(matrixInterpretation), + matrixStride); } [require(spirv, cooperative_vector_training)] @@ -32223,19 +31953,7 @@ void coopVecReduceSumAccumulate(buffer, offset, v); } ${ @@ -32243,6 +31961,40 @@ ${ } // buffer type loop $} +[ForceInline] +[require(spirv, cooperative_vector_training)] +__generic +void __coopVecOuterProductAccumulateFromPointer( + void* matrixPtr, + int32_t matrixOffset, + CoopVec a, + CoopVec b, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, + constexpr uint matrixStride) +{ + static_assert(__componentPackingFactor(matrixInterpretation) == 1, "can't use packed component type in matrixInterpretation"); + + __coopVecOuterProductAccumulate( + (Ptr)matrixPtr, + matrixOffset, + a, + b, + memoryLayout, + __getCoopVecComponentScalarType(matrixInterpretation), + matrixStride); +} + +[ForceInline] +[require(spirv, cooperative_vector_training)] +__generic +void __coopVecReduceSumAccumulateFromPointer( + void* bufferPtr, + int32_t offset, + CoopVec v) +{ + __coopVecReduceSumAccumulate((Ptr)bufferPtr, offset, v); +} [require(spirv, cooperative_vector_training)] void coopVecOuterProductAccumulate( @@ -32255,17 +32007,15 @@ void coopVecOuterProductAccumulate( + matrixPtr0, + matrixOffset, + a, + b, + memoryLayout, + matrixInterpretation, + matrixStride); } [require(spirv, cooperative_vector_training)] @@ -32275,15 +32025,8 @@ void coopVecReduceSumAccumulate(bufferPtr0, offset, v); } // Pointer overloads for coopvector operations. @@ -32304,31 +32047,23 @@ CoopVec coopVecMatMulPacked( static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK , "for non-packed inputInterpretation values k must be equal to the input vector length"); static_assert(!__isPackedInputInterpretation(inputInterpretation) - || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + || k <= __componentPackingFactor(inputInterpretation)*PackedK , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); - __target_switch - { - case spirv: - let m : int32_t = M; - let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); - let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); - let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); - int operands = 0; // NoneKHR - let zero = 0; - let cvtMatPtr = (Ptr)matrixPtr; - if (__isSignedInt()) - { - operands |= 0x02; // MatrixBSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x08; // MatrixResultSignedComponentsKHR - } - return spirv_asm - { - result:$$CoopVec = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; - } + static_assert(__componentPackingFactor(matrixInterpretation) == 1, "can't use packed component type in matrixInterpretation"); + + let zero = 0; + let cvtMatPtr = (Ptr)matrixPtr; + return __coopVecMatMul( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + cvtMatPtr, + zero, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride); } // specialized coopVecMatMul for non-packed inputs @@ -32375,34 +32110,28 @@ CoopVec coopVecMatMulAddPacked)matrixPtr; - let cvtBiasPtr = (Ptr)biasPtr; - int operands = 0; // NoneKHR - if (__isSignedInt()) - { - operands |= 0x02; // MatrixBSignedComponentsKHR - } - if (__isSignedInt()) - { - operands |= 0x08; // MatrixResultSignedComponentsKHR - } - return spirv_asm - { - result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; - }; - } + let zero : int32_t = 0; + let cvtMatPtr = (Ptr)matrixPtr; + let cvtBiasPtr = (Ptr)biasPtr; + return __coopVecMatMulAdd( + input, + __getCoopVecComponentScalarType(inputInterpretation), + __componentPackingFactor(inputInterpretation), + cvtMatPtr, + zero, + __getCoopVecComponentScalarType(matrixInterpretation), + k, + memoryLayout, + transpose, + matrixStride, + cvtBiasPtr, + zero, + __getCoopVecComponentScalarType(biasInterpretation)); } [require(spirv, cooperative_vector)] @@ -32445,18 +32174,14 @@ void coopVecOuterProductAccumulate)matrixPtr; - spirv_asm - { - OpCapability CooperativeVectorTrainingNV; - OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; - }; - } + __coopVecOuterProductAccumulateFromPointer( + matrixPtr, + zero, + a, + b, + memoryLayout, + matrixInterpretation, + matrixStride); } [require(spirv, cooperative_vector_training)] @@ -32466,16 +32191,7 @@ void coopVecReduceSumAccumulate( ) { let zero : int32_t = 0; - let bufferPtr = (Ptr)(buffer); - __target_switch - { - case spirv: - spirv_asm - { - OpCapability CooperativeVectorTrainingNV; - OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v; - }; - } + __coopVecReduceSumAccumulateFromPointer(buffer, zero, v); } //@public: diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 3cec47182ab..a6be4aee2d0 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1464,6 +1464,10 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) case kIROp_MetalCastToDepthTexture: case kIROp_LoadResourceDescriptorFromHeap: case kIROp_LoadSamplerDescriptorFromHeap: + case kIROp_CoopMatMulAdd: + case kIROp_CoopVecMatMulAdd: + case kIROp_CoopVecOuterProductAccumulate: + case kIROp_CoopVecReduceSumAccumulate: return false; // Always fold these in, because they are trivial @@ -3241,6 +3245,10 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) case kIROp_MetalAtomicCast: case kIROp_MetalCastToDepthTexture: case kIROp_SetOptiXPayloadRegister: + case kIROp_CoopMatMulAdd: + case kIROp_CoopVecMatMulAdd: + case kIROp_CoopVecOuterProductAccumulate: + case kIROp_CoopVecReduceSumAccumulate: emitInstStmt(inst); break; diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 5e5b42a18b1..45a3427084a 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -11,6 +11,61 @@ namespace Slang { +static void emitUnsupportedTargetIntrinsicExpr( + CUDASourceEmitter* emitter, + IRInst* inst, + const char* operation, + SourceLoc location) +{ + emitter->getSink()->diagnose( + Diagnostics::UnsupportedTargetIntrinsic{.operation = operation, .location = location}); + emitter->getSourceWriter()->emit("("); + emitter->emitType(inst->getDataType()); + emitter->getSourceWriter()->emit("{})"); +} + +static UnownedStringSlice getOptixCoopVecComponentTypeName(int componentType) +{ + switch (componentType) + { + case SLANG_SCALAR_TYPE_FLOAT_E4M3: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3"); + case SLANG_SCALAR_TYPE_FLOAT_E5M2: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2"); + case SLANG_SCALAR_TYPE_FLOAT16: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16"); + case SLANG_SCALAR_TYPE_FLOAT32: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32"); + case SLANG_SCALAR_TYPE_INT8: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_INT8"); + case SLANG_SCALAR_TYPE_INT32: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_INT32"); + case SLANG_SCALAR_TYPE_UINT8: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_UINT8"); + case SLANG_SCALAR_TYPE_UINT32: + return UnownedStringSlice("OPTIX_COOP_VEC_ELEM_TYPE_UINT32"); + default: + return UnownedStringSlice(); + } +} + +static UnownedStringSlice getOptixCoopVecMatrixLayoutName(int matrixLayout) +{ + switch (matrixLayout) + { + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR: + return UnownedStringSlice("OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR: + return UnownedStringSlice("OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL: + return UnownedStringSlice("OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL: + return UnownedStringSlice("OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL"); + default: + SLANG_UNEXPECTED("invalid OptiX cooperative vector matrix layout"); + } +} + static CUDAExtensionTracker::BaseTypeFlags _findBaseTypesUsed(IRModule* module) { typedef CUDAExtensionTracker::BaseTypeFlags Flags; @@ -768,6 +823,117 @@ bool CUDASourceEmitter::tryEmitInstStmtImpl(IRInst* inst) m_writer->emit(", -1);\n"); return true; } + case kIROp_CoopVecMatMulAdd: + { + if (!isOptixCoopVec) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = "cooperative vector matrix multiply-add", + .location = inst->sourceLoc}); + _emitInstAsDefaultInitializedVar(inst, inst->getDataType()); + return true; + } + + emitInstResultDecl(inst); + emitInstExpr(inst, getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_CoopMatMulAdd: + { + emitInstResultDecl(inst); + emitInstExpr(inst, getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_CoopVecOuterProductAccumulate: + { + if (!isOptixCoopVec) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = "cooperative vector outer-product accumulate", + .location = inst->sourceLoc}); + m_writer->emit("/* unsupported cooperative vector outer-product accumulate */\n"); + return true; + } + + auto outerProduct = cast(inst); + auto matrixLayout = cast(outerProduct->getMemoryLayout())->getValue(); + auto matrixInterpretation = + cast(outerProduct->getMatrixInterpretation())->getValue(); + + if (matrixLayout != SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = + "cooperative vector outer-product accumulate requires TrainingOptimal " + "matrix layout for OptiX", + .location = inst->sourceLoc}); + m_writer->emit("/* unsupported cooperative vector outer-product accumulate */\n"); + return true; + } + + if (matrixInterpretation != SLANG_SCALAR_TYPE_FLOAT16) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = + "cooperative vector outer-product accumulate requires Float16 matrix " + "interpretation for OptiX", + .location = inst->sourceLoc}); + m_writer->emit("/* unsupported cooperative vector outer-product accumulate */\n"); + return true; + } + + m_writer->emit("optixCoopVecOuterProductAccumulate("); + emitOperand(outerProduct->getA(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(outerProduct->getB(), getInfo(EmitOp::General)); + m_writer->emit(", (CUdeviceptr)(&("); + emitOperand(outerProduct->getMatrixPtr(), getInfo(EmitOp::General)); + m_writer->emit(")), "); + emitOperand(outerProduct->getMatrixOffset(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(outerProduct->getMatrixStride(), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_CoopVecReduceSumAccumulate: + { + if (!isOptixCoopVec) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = "cooperative vector reduce-sum accumulate", + .location = inst->sourceLoc}); + m_writer->emit("/* unsupported cooperative vector reduce-sum accumulate */\n"); + return true; + } + + auto reduceSum = cast(inst); + auto valueType = as(reduceSum->getValue()->getDataType()); + SLANG_ASSERT(valueType); + auto valueElementType = as(valueType->getElementType()); + SLANG_ASSERT(valueElementType); + if (valueElementType->getBaseType() != BaseType::Half && + valueElementType->getBaseType() != BaseType::Float) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = + "cooperative vector reduce-sum accumulate requires Float16 or Float32 " + "vector element type for OptiX", + .location = inst->sourceLoc}); + m_writer->emit("/* unsupported cooperative vector reduce-sum accumulate */\n"); + return true; + } + + m_writer->emit("optixCoopVecReduceSumAccumulate("); + emitOperand(reduceSum->getValue(), getInfo(EmitOp::General)); + m_writer->emit(", (CUdeviceptr)(&("); + emitOperand(reduceSum->getBufferPtr(), getInfo(EmitOp::General)); + m_writer->emit(")), "); + emitOperand(reduceSum->getOffset(), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } case kIROp_SetOptiXPayloadRegister: { auto idxInst = as(inst->getOperand(0)); @@ -921,6 +1087,172 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_CoopMatMulAdd: + { + auto coopMatMulAdd = cast(inst); + auto matA = coopMatMulAdd->getMatA(); + auto matB = coopMatMulAdd->getMatB(); + auto matC = coopMatMulAdd->getMatC(); + auto saturatingAccumulation = + cast(coopMatMulAdd->getSaturatingAccumulation())->getValue(); + + m_writer->emit("Slang_CUDA_WMMA::coopMatMulAdd<"); + emitType(matA->getDataType()); + m_writer->emit("::ElementType, "); + emitType(matB->getDataType()); + m_writer->emit("::ElementType, "); + emitType(matC->getDataType()); + m_writer->emit("::ElementType, "); + emitType(coopMatMulAdd->getDataType()); + m_writer->emit("::ElementType, "); + emitType(matA->getDataType()); + m_writer->emit("::m_M, "); + emitType(matA->getDataType()); + m_writer->emit("::m_N, "); + emitType(matA->getDataType()); + m_writer->emit("::m_K, "); + m_writer->emit(saturatingAccumulation ? "true" : "false"); + m_writer->emit(">("); + emitOperand(matA, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(matB, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(matC, getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_CoopVecMatMulAdd: + { + // CoopVec matmul ops are always emitted as statements, so non-OptiX handling lives in + // tryEmitInstStmtImpl(). + SLANG_ASSERT(isOptixCoopVec); + + auto coopVecMatMulAdd = cast(inst); + auto inputInterpretationPackingFactor = + cast(coopVecMatMulAdd->getInputInterpretationPackingFactor())->getValue(); + auto inputInterpretation = + cast(coopVecMatMulAdd->getInputInterpretation())->getValue(); + auto matrixInterpretation = + cast(coopVecMatMulAdd->getMatrixInterpretation())->getValue(); + auto biasInterpretation = coopVecMatMulAdd->getBiasInterpretation(); + const bool hasBias = biasInterpretation != nullptr; + + if (inputInterpretationPackingFactor != 1) + { + emitUnsupportedTargetIntrinsicExpr( + this, + inst, + "cooperative vector matrix multiply-add with packed input is not implemented " + "yet", + inst->sourceLoc); + return true; + } + + auto inputInterpretationName = + getOptixCoopVecComponentTypeName((uint32_t)inputInterpretation); + if (!inputInterpretationName.getLength()) + { + emitUnsupportedTargetIntrinsicExpr( + this, + inst, + "cooperative vector matrix multiply-add with unsupported OptiX input " + "interpretation type", + inst->sourceLoc); + return true; + } + + auto matrixInterpretationName = + getOptixCoopVecComponentTypeName((uint32_t)matrixInterpretation); + if (!matrixInterpretationName.getLength()) + { + emitUnsupportedTargetIntrinsicExpr( + this, + inst, + "cooperative vector matrix multiply-add with unsupported OptiX matrix " + "interpretation type", + inst->sourceLoc); + return true; + } + + auto matrixLayout = cast(coopVecMatMulAdd->getMemoryLayout())->getValue(); + auto matrixLayoutName = getOptixCoopVecMatrixLayoutName((uint32_t)matrixLayout); + + auto transposeValue = cast(coopVecMatMulAdd->getTranspose())->getValue(); + if (transposeValue) + { + if (matrixInterpretation != SLANG_SCALAR_TYPE_FLOAT16 || + (matrixLayout != SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL && + matrixLayout != SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL)) + { + emitUnsupportedTargetIntrinsicExpr( + this, + inst, + "cooperative vector matrix multiply-add with transpose requires Float16 " + "matrix interpretation and InferencingOptimal or TrainingOptimal matrix " + "layout for OptiX", + inst->sourceLoc); + return true; + } + } + + UnownedStringSlice biasInterpretationName; + if (hasBias) + { + biasInterpretationName = getOptixCoopVecComponentTypeName( + (uint32_t)cast(biasInterpretation)->getValue()); + if (!biasInterpretationName.getLength()) + { + emitUnsupportedTargetIntrinsicExpr( + this, + inst, + "cooperative vector matrix multiply-add with unsupported OptiX bias " + "interpretation type", + inst->sourceLoc); + return true; + } + } + + m_writer->emit("("); + m_writer->emit("slangOptixCoopVecMatMul<"); + emitType(inst->getDataType()); + m_writer->emit(", "); + emitType(coopVecMatMulAdd->getInput()->getDataType()); + m_writer->emit(", "); + m_writer->emit(inputInterpretationName); + m_writer->emit(", "); + m_writer->emit(matrixInterpretationName); + m_writer->emit(", "); + m_writer->emit(matrixLayoutName); + if (hasBias) + { + m_writer->emit(", "); + m_writer->emit(biasInterpretationName); + } + m_writer->emit(">(("); + emitOperand(coopVecMatMulAdd->getInput(), getInfo(EmitOp::General)); + m_writer->emit("), (CUdeviceptr)(&(("); + emitOperand(coopVecMatMulAdd->getMatrixPtr(), getInfo(EmitOp::General)); + m_writer->emit("))), "); + emitOperand(coopVecMatMulAdd->getMatrixOffset(), getInfo(EmitOp::General)); + if (hasBias) + { + m_writer->emit(", (CUdeviceptr)(&(("); + emitOperand(coopVecMatMulAdd->getBiasPtr(), getInfo(EmitOp::General)); + m_writer->emit("))), "); + emitOperand(coopVecMatMulAdd->getBiasOffset(), getInfo(EmitOp::General)); + } + else if ( + as( + coopVecMatMulAdd->getMatrixPtr()->getDataType()) == nullptr) + { + m_writer->emit(", "); + emitOperand(coopVecMatMulAdd->getTranspose(), getInfo(EmitOp::General)); + } + m_writer->emit(", "); + emitOperand(coopVecMatMulAdd->getMatrixStride(), getInfo(EmitOp::General)); + m_writer->emit("))"); + return true; + } case kIROp_MakeArray: { IRType* dataType = inst->getDataType(); @@ -1198,11 +1530,11 @@ static bool typeCheck(IROp op, uint32_t matrixUse) { switch (matrixUse) { - case 0: // matrixA - case 1: // matrixB + case SLANG_COOPERATIVE_MATRIX_USE_A: + case SLANG_COOPERATIVE_MATRIX_USE_B: return op == kIROp_UInt8Type || op == kIROp_Int8Type || op == kIROp_HalfType || op == kIROp_BFloat16Type || op == kIROp_FloatE4M3Type || op == kIROp_FloatE5M2Type; - case 2: // accumulator + case SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR: return op == kIROp_IntType || op == kIROp_HalfType || op == kIROp_FloatType; } return false; @@ -1212,14 +1544,14 @@ static UnownedStringSlice getMatrixUseName(uint32_t matrixUse) { switch (matrixUse) { - case 0: + case SLANG_COOPERATIVE_MATRIX_USE_A: return UnownedStringSlice("Slang_CUDA_WMMA::MatrixA"); - case 1: + case SLANG_COOPERATIVE_MATRIX_USE_B: return UnownedStringSlice("Slang_CUDA_WMMA::MatrixB"); - case 2: + case SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR: return UnownedStringSlice("Slang_CUDA_WMMA::MatrixC"); default: - return UnownedStringSlice(); + SLANG_UNEXPECTED("invalid cooperative matrix use"); } } @@ -1243,7 +1575,7 @@ inline FragmentShape computeShapeCombination(uint32_t matrixUse, uint32_t row, u { switch (matrixUse) { - case 0: // Matrix A: row=m, col=k + case SLANG_COOPERATIVE_MATRIX_USE_A: // Matrix A: row=m, col=k { // k must always be 16 if (col != 16) @@ -1263,7 +1595,7 @@ inline FragmentShape computeShapeCombination(uint32_t matrixUse, uint32_t row, u return {0, 0, 0}; // Invalid } } - case 1: // Matrix B: row=k, col=n + case SLANG_COOPERATIVE_MATRIX_USE_B: // Matrix B: row=k, col=n { // k must always be 16 if (row != 16) @@ -1283,7 +1615,7 @@ inline FragmentShape computeShapeCombination(uint32_t matrixUse, uint32_t row, u return {0, 0, 0}; // Invalid } } - case 2: // Matrix C/D: row=m, col=n + case SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR: // Matrix C/D: row=m, col=n default: { // Check exact (m, n) combinations @@ -1319,7 +1651,9 @@ SlangResult CUDASourceEmitter::emitWMMAFragmentType( { getSink()->diagnose(Diagnostics::CooperativeMatrixUnsupportedElementType{ .elementType = typeName, - .matrixUse = matrixUse == 0 ? "A" : (matrixUse == 1 ? "B" : "C")}); + .matrixUse = matrixUse == SLANG_COOPERATIVE_MATRIX_USE_A + ? "A" + : (matrixUse == SLANG_COOPERATIVE_MATRIX_USE_B ? "B" : "C")}); SLANG_RELEASE_ASSERT(false); return SLANG_FAIL; } @@ -1332,7 +1666,9 @@ SlangResult CUDASourceEmitter::emitWMMAFragmentType( getSink()->diagnose(Diagnostics::CooperativeMatrixInvalidShape{ .rowCount = String(rowCount), .colCount = String(colCount), - .matrixUse = matrixUse == 0 ? "A" : (matrixUse == 1 ? "B" : "C")}); + .matrixUse = matrixUse == SLANG_COOPERATIVE_MATRIX_USE_A + ? "A" + : (matrixUse == SLANG_COOPERATIVE_MATRIX_USE_B ? "B" : "C")}); SLANG_RELEASE_ASSERT(false); return SLANG_FAIL; } diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 4cf335942a5..8dc8ac3c793 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -29,6 +29,74 @@ double _slang_asdouble(uint64_t x) } )"; +static UnownedStringSlice _mapSlangCoopVecComponentTypeToHLSL( + int32_t slangValue, + IRIntegerValue inputInterpretationPackingFactor) +{ + if (inputInterpretationPackingFactor != 1) + { + switch (slangValue) + { + case SLANG_SCALAR_TYPE_INT8: + return UnownedStringSlice("SignedInt8Packed"); + case SLANG_SCALAR_TYPE_UINT8: + return UnownedStringSlice("UnsignedInt8Packed"); + default: + SLANG_UNEXPECTED( + "Unsupported packed cooperative vector input interpretation for HLSL emission"); + } + } + + switch (slangValue) + { + case SLANG_SCALAR_TYPE_FLOAT_E4M3: + return UnownedStringSlice("FloatE4M3"); + case SLANG_SCALAR_TYPE_FLOAT_E5M2: + return UnownedStringSlice("FloatE5M2"); + case SLANG_SCALAR_TYPE_FLOAT16: + return UnownedStringSlice("Float16"); + case SLANG_SCALAR_TYPE_FLOAT32: + return UnownedStringSlice("Float32"); + case SLANG_SCALAR_TYPE_FLOAT64: + return UnownedStringSlice("Float64"); + case SLANG_SCALAR_TYPE_INT8: + return UnownedStringSlice("SignedInt8"); + case SLANG_SCALAR_TYPE_INT16: + return UnownedStringSlice("SignedInt16"); + case SLANG_SCALAR_TYPE_INT32: + return UnownedStringSlice("SignedInt32"); + case SLANG_SCALAR_TYPE_INT64: + return UnownedStringSlice("SignedInt64"); + case SLANG_SCALAR_TYPE_UINT8: + return UnownedStringSlice("UnsignedInt8"); + case SLANG_SCALAR_TYPE_UINT16: + return UnownedStringSlice("UnsignedInt16"); + case SLANG_SCALAR_TYPE_UINT32: + return UnownedStringSlice("UnsignedInt32"); + case SLANG_SCALAR_TYPE_UINT64: + return UnownedStringSlice("UnsignedInt64"); + default: + SLANG_UNEXPECTED("Unsupported cooperative vector component type for HLSL emission"); + } +} + +static UnownedStringSlice _mapSlangCoopVecMatrixLayoutToHLSL(int32_t slangValue) +{ + switch (slangValue) + { + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR: + return UnownedStringSlice("RowMajor"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR: + return UnownedStringSlice("ColumnMajor"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL: + return UnownedStringSlice("InferencingOptimal"); + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL: + return UnownedStringSlice("TrainingOptimal"); + default: + SLANG_UNEXPECTED("Unsupported cooperative vector matrix layout for HLSL emission"); + } +} + void HLSLSourceEmitter::_emitHLSLDecorationSingleString( const char* name, IRFunc* entryPoint, @@ -593,6 +661,40 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl( } } +void HLSLSourceEmitter::emitMappedCoopVecComponentType( + IRInst* operand, + IRInst* inputInterpretationPackingFactor) +{ + auto intLit = cast(operand); + + if (intLit->getValue() == SLANG_SCALAR_TYPE_BFLOAT16) + { + getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = "BFloat16 cooperative vector component type", + .location = operand->sourceLoc}); + m_writer->emit("0"); + return; + } + + IRIntegerValue inputInterpretationPackingFactorValue = 1; + if (inputInterpretationPackingFactor) + { + inputInterpretationPackingFactorValue = + cast(inputInterpretationPackingFactor)->getValue(); + } + + m_writer->emit(_mapSlangCoopVecComponentTypeToHLSL( + (int32_t)intLit->getValue(), + inputInterpretationPackingFactorValue)); +} + +void HLSLSourceEmitter::emitMappedCoopVecMatrixLayout(IRInst* operand) +{ + auto intLit = cast(operand); + + m_writer->emit(_mapSlangCoopVecMatrixLayoutToHLSL((int32_t)intLit->getValue())); +} + bool HLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) { auto diagnoseFloatAtommic = [&]() @@ -780,6 +882,104 @@ bool HLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) m_writer->emit(");"); return true; } + case kIROp_CoopVecMatMulAdd: + { + auto coopVecMatMulAdd = cast(inst); + auto input = coopVecMatMulAdd->getInput(); + auto matrixPtr = coopVecMatMulAdd->getMatrixPtr(); + auto matrixOffset = coopVecMatMulAdd->getMatrixOffset(); + auto matrixInterpretation = coopVecMatMulAdd->getMatrixInterpretation(); + auto biasPtr = coopVecMatMulAdd->getBiasPtr(); + auto biasOffset = coopVecMatMulAdd->getBiasOffset(); + auto biasInterpretation = coopVecMatMulAdd->getBiasInterpretation(); + auto k = coopVecMatMulAdd->getK(); + auto memoryLayout = coopVecMatMulAdd->getMemoryLayout(); + auto transpose = coopVecMatMulAdd->getTranspose(); + auto matrixStride = coopVecMatMulAdd->getMatrixStride(); + bool hasBias = biasInterpretation != nullptr; + + auto resultType = cast(inst->getDataType()); + auto inputType = cast(input->getDataType()); + + const bool outputIsUnsigned = isScalarIntegerType(resultType->getElementType()) && + !getIntTypeSigned(resultType->getElementType()); + const bool inputIsUnsigned = isScalarIntegerType(inputType->getElementType()) && + !getIntTypeSigned(inputType->getElementType()); + + emitInstResultDecl(inst); + emitType(inst->getDataType()); + m_writer->emit("(0);\n"); + + m_writer->emit(hasBias ? "__builtin_MatVecMulAdd(" : "__builtin_MatVecMul("); + m_writer->emit(getName(inst)); + m_writer->emit(outputIsUnsigned ? ", true, " : ", false, "); + emitOperand(input, getInfo(EmitOp::General)); + m_writer->emit(inputIsUnsigned ? ", true, " : ", false, "); + emitMappedCoopVecComponentType( + coopVecMatMulAdd->getInputInterpretation(), + coopVecMatMulAdd->getInputInterpretationPackingFactor()); + m_writer->emit(", "); + emitOperand(matrixPtr, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(matrixOffset, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMappedCoopVecComponentType(matrixInterpretation); + m_writer->emit(", "); + emitOperand(resultType->getElementCount(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(k, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMappedCoopVecMatrixLayout(memoryLayout); + m_writer->emit(", "); + emitOperand(transpose, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(matrixStride, getInfo(EmitOp::General)); + if (hasBias) + { + m_writer->emit(", "); + emitOperand(biasPtr, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(biasOffset, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMappedCoopVecComponentType(biasInterpretation); + } + m_writer->emit(");\n"); + return true; + } + case kIROp_CoopVecOuterProductAccumulate: + { + auto outerProduct = cast(inst); + + m_writer->emit("__builtin_OuterProductAccumulate("); + emitOperand(outerProduct->getA(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(outerProduct->getB(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(outerProduct->getMatrixPtr(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(outerProduct->getMatrixOffset(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMappedCoopVecComponentType(outerProduct->getMatrixInterpretation()); + m_writer->emit(", "); + emitMappedCoopVecMatrixLayout(outerProduct->getMemoryLayout()); + m_writer->emit(", "); + emitOperand(outerProduct->getMatrixStride(), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_CoopVecReduceSumAccumulate: + { + auto reduceSum = cast(inst); + + m_writer->emit("__builtin_VectorAccumulate("); + emitOperand(reduceSum->getValue(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(reduceSum->getBufferPtr(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(reduceSum->getOffset(), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } default: return false; } diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index 99a25b96b4c..2309a92c943 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -81,6 +81,11 @@ class HLSLSourceEmitter : public CLikeSourceEmitter virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; + + void emitMappedCoopVecComponentType( + IRInst* operand, + IRInst* inputInterpretationPackingFactor = nullptr); + void emitMappedCoopVecMatrixLayout(IRInst* operand); virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE; virtual void emitFuncDecorationsImpl(IRFunc* func) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index df02ea1d410..595409284c0 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2334,7 +2334,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex static_cast(coopMatType->getColumnCount())->getValue(), builder.getIntType()), emitIntConstant( - static_cast(coopMatType->getMatrixUse())->getValue(), + mapSlangCooperativeMatrixUseToSpv( + static_cast(coopMatType->getMatrixUse())->getValue()), builder.getIntType())); } case kIROp_TensorAddressingTensorLayoutType: @@ -5025,6 +5026,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_CoopMatMapElementIFunc: result = emitCoopMatMapElementWithIFunc(parent, as(inst)); break; + case kIROp_CoopMatMulAdd: + result = emitCoopMatMulAdd(parent, inst); + break; + case kIROp_CoopVecMatMulAdd: + result = emitCoopVecMatMulAdd(parent, inst); + break; + case kIROp_CoopVecOuterProductAccumulate: + result = emitCoopVecOuterProductAccumulate(parent, inst); + break; + case kIROp_CoopVecReduceSumAccumulate: + result = emitCoopVecReduceSumAccumulate(parent, inst); + break; case kIROp_MakeTensorAddressingTensorLayout: result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); break; @@ -8300,6 +8313,41 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + // Emit an operand that satisfies cooperative vector SPIRV opcodes, which require a pointer + // to an array type (not a pointer to a struct). + // + // - Buffer resource case (ByteAddressBuffer / StructuredBuffer): after SPIRV legalization + // the global param has type ptr-to-struct{runtimeArray}. Emit an OpAccessChain with + // index 0 to pierce through the wrapper struct and expose the runtime array. + // + // - Ptr case: the pointer already points directly to the unsized array, so return it + // as-is. + SpvInst* emitBufferPtrAsArrayPtr(SpvInstParent* parent, IRInst* bufferVal) + { + IRBuilder builder(bufferVal); + auto addressSpace = + isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; + IRPtrTypeBase* bufPtrType = cast(bufferVal->getDataType()); + // If the pointee is not a struct, the pointer already targets an array directly + // (e.g. Ptr) — use it without modification. + IRStructType* bufType = as(bufPtrType->getValueType()); + if (!bufType) + return ensureInst(bufferVal); + // The struct's first (and only) field is the runtime array of elements. + IRArrayTypeBase* arrayType = + cast(bufType->getFields().getFirst()->getFieldType()); + return emitOpAccessChain( + parent, + nullptr, + builder.getPtrType( + arrayType, + AccessQualifier::ReadWrite, + addressSpace, + bufPtrType->getDataLayout()), + bufferVal, + makeArray(emitIntConstant(0, builder.getIntType()))); + } + SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst) { IRBuilder builder(inst); @@ -8982,6 +9030,269 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex }); } + bool isSignedIntegerScalarType(IRType* type) + { + return isScalarIntegerType(type) && getIntTypeSigned(type); + } + + SpvInst* emitCoopMatMulAdd(SpvInstParent* parent, IRInst* inst) + { + requireSPIRVCapability(SpvCapabilityCooperativeMatrixKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_cooperative_matrix")); + + SLANG_ASSERT(inst->getOperandCount() == 4); + + auto coopMatMulAdd = cast(inst); + auto saturatingAccumulation = + cast(coopMatMulAdd->getSaturatingAccumulation())->getValue(); + + uint32_t operandsMask = 0; + auto aType = cast(coopMatMulAdd->getMatA()->getDataType()); + auto bType = cast(coopMatMulAdd->getMatB()->getDataType()); + auto cType = cast(coopMatMulAdd->getMatC()->getDataType()); + auto resultType = cast(inst->getDataType()); + + if (isSignedIntegerScalarType(aType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask; + if (isSignedIntegerScalarType(bType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask; + if (isSignedIntegerScalarType(cType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask; + if (isSignedIntegerScalarType(resultType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask; + + if (saturatingAccumulation) + operandsMask |= SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask; + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeMatrixMulAddKHR, + [&]() + { + emitOperand(inst->getFullType()); + emitOperand(kResultID); + emitOperand(coopMatMulAdd->getMatA()); + emitOperand(coopMatMulAdd->getMatB()); + emitOperand(coopMatMulAdd->getMatC()); + if (operandsMask) + emitOperand(SpvLiteralInteger::from32(operandsMask)); + }); + } + + IRIntegerValue mapSlangCooperativeMatrixUseToSpv(IRIntegerValue slangValue) + { + switch ((int32_t)slangValue) + { + case SLANG_COOPERATIVE_MATRIX_USE_A: + return SpvCooperativeMatrixUseMatrixAKHR; + case SLANG_COOPERATIVE_MATRIX_USE_B: + return SpvCooperativeMatrixUseMatrixBKHR; + case SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR: + return SpvCooperativeMatrixUseMatrixAccumulatorKHR; + default: + SLANG_UNEXPECTED("Unsupported cooperative matrix use for SPIR-V emission"); + } + } + + IRIntegerValue mapSlangCoopVecMatrixLayoutToSpv(IRIntegerValue slangValue) + { + switch ((int32_t)slangValue) + { + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR: + return SpvCooperativeVectorMatrixLayoutRowMajorNV; + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR: + return SpvCooperativeVectorMatrixLayoutColumnMajorNV; + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL: + return SpvCooperativeVectorMatrixLayoutInferencingOptimalNV; + case SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL: + return SpvCooperativeVectorMatrixLayoutTrainingOptimalNV; + default: + SLANG_UNEXPECTED("Unsupported cooperative vector matrix layout for SPIR-V emission"); + } + } + + IRIntegerValue mapSlangCoopVecComponentTypeToSpv( + IRIntegerValue slangValue, + IRIntegerValue inputInterpretationPackingFactor) + { + if (inputInterpretationPackingFactor != 1) + { + switch ((int32_t)slangValue) + { + case SLANG_SCALAR_TYPE_INT8: + return SpvComponentTypeSignedInt8PackedNV; + case SLANG_SCALAR_TYPE_UINT8: + return SpvComponentTypeUnsignedInt8PackedNV; + default: + SLANG_UNEXPECTED("Unsupported packed cooperative vector input interpretation for " + "SPIR-V emission"); + } + } + + switch ((int32_t)slangValue) + { + case SLANG_SCALAR_TYPE_FLOAT_E4M3: + return SpvComponentTypeFloatE4M3NV; + case SLANG_SCALAR_TYPE_FLOAT_E5M2: + return SpvComponentTypeFloatE5M2NV; + case SLANG_SCALAR_TYPE_FLOAT16: + return SpvComponentTypeFloat16NV; + case SLANG_SCALAR_TYPE_FLOAT32: + return SpvComponentTypeFloat32NV; + case SLANG_SCALAR_TYPE_FLOAT64: + return SpvComponentTypeFloat64NV; + case SLANG_SCALAR_TYPE_INT8: + return SpvComponentTypeSignedInt8NV; + case SLANG_SCALAR_TYPE_INT16: + return SpvComponentTypeSignedInt16NV; + case SLANG_SCALAR_TYPE_INT32: + return SpvComponentTypeSignedInt32NV; + case SLANG_SCALAR_TYPE_INT64: + return SpvComponentTypeSignedInt64NV; + case SLANG_SCALAR_TYPE_UINT8: + return SpvComponentTypeUnsignedInt8NV; + case SLANG_SCALAR_TYPE_UINT16: + return SpvComponentTypeUnsignedInt16NV; + case SLANG_SCALAR_TYPE_UINT32: + return SpvComponentTypeUnsignedInt32NV; + case SLANG_SCALAR_TYPE_UINT64: + return SpvComponentTypeUnsignedInt64NV; + default: + SLANG_UNEXPECTED("Unsupported cooperative vector component type for SPIR-V emission"); + } + } + + void emitMappedCoopVecMatrixLayoutOperand(IRInst* operand) + { + auto intLit = cast(operand); + emitOperand(emitIntConstant( + mapSlangCoopVecMatrixLayoutToSpv(intLit->getValue()), + operand->getDataType())); + } + + void emitMappedCoopVecComponentTypeOperand( + IRInst* operand, + IRInst* inputInterpretationPackingFactor = nullptr) + { + auto intLit = cast(operand); + + IRIntegerValue packingFactor = 1; + if (inputInterpretationPackingFactor) + { + packingFactor = cast(inputInterpretationPackingFactor)->getValue(); + } + + if (intLit->getValue() == SLANG_SCALAR_TYPE_BFLOAT16) + { + m_sink->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ + .operation = "BFloat16 cooperative vector component type", + .location = operand->sourceLoc}); + emitOperand(emitIntConstant(0, operand->getDataType())); + return; + } + + emitOperand(emitIntConstant( + mapSlangCoopVecComponentTypeToSpv(intLit->getValue(), packingFactor), + operand->getDataType())); + } + + SpvInst* emitCoopVecMatMulAdd(SpvInstParent* parent, IRInst* inst) + { + requireSPIRVCapability(SpvCapabilityCooperativeVectorNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector")); + + SLANG_ASSERT(inst->getOperandCount() == 10 || inst->getOperandCount() == 13); + + auto coopVecMatMulAdd = cast(inst); + auto k = cast(coopVecMatMulAdd->getK()); + auto transpose = cast(coopVecMatMulAdd->getTranspose()); + bool hasBias = coopVecMatMulAdd->getBiasInterpretation() != nullptr; + + uint32_t operandsMask = 0; + auto inputType = cast(coopVecMatMulAdd->getInput()->getDataType()); + auto resultType = cast(inst->getDataType()); + auto resultElementCount = cast(resultType->getElementCount()); + if (isSignedIntegerScalarType(inputType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask; + if (isSignedIntegerScalarType(resultType->getElementType())) + operandsMask |= SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask; + + return emitInstCustomOperandFunc( + parent, + inst, + hasBias ? SpvOpCooperativeVectorMatrixMulAddNV : SpvOpCooperativeVectorMatrixMulNV, + [&]() + { + emitOperand(inst->getFullType()); + emitOperand(kResultID); + emitOperand(coopVecMatMulAdd->getInput()); + emitMappedCoopVecComponentTypeOperand( + coopVecMatMulAdd->getInputInterpretation(), + coopVecMatMulAdd->getInputInterpretationPackingFactor()); + emitOperand(emitBufferPtrAsArrayPtr(parent, coopVecMatMulAdd->getMatrixPtr())); + emitOperand(coopVecMatMulAdd->getMatrixOffset()); + emitMappedCoopVecComponentTypeOperand(coopVecMatMulAdd->getMatrixInterpretation()); + if (hasBias) + { + emitOperand(emitBufferPtrAsArrayPtr(parent, coopVecMatMulAdd->getBiasPtr())); + emitOperand(coopVecMatMulAdd->getBiasOffset()); + emitMappedCoopVecComponentTypeOperand( + coopVecMatMulAdd->getBiasInterpretation()); + } + emitOperand(resultElementCount); + emitOperand(emitIntConstant(k->getValue(), k->getDataType())); + emitMappedCoopVecMatrixLayoutOperand(coopVecMatMulAdd->getMemoryLayout()); + emitOperand(transpose); + emitOperand(coopVecMatMulAdd->getMatrixStride()); + if (operandsMask) + emitOperand(SpvLiteralInteger::from32(operandsMask)); + }); + } + + SpvInst* emitCoopVecOuterProductAccumulate(SpvInstParent* parent, IRInst* inst) + { + requireSPIRVCapability(SpvCapabilityCooperativeVectorTrainingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector")); + + SLANG_ASSERT(inst->getOperandCount() == 7); + + auto outerProduct = cast(inst); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeVectorOuterProductAccumulateNV, + [&]() + { + emitOperand(emitBufferPtrAsArrayPtr(parent, outerProduct->getMatrixPtr())); + emitOperand(outerProduct->getMatrixOffset()); + emitOperand(outerProduct->getA()); + emitOperand(outerProduct->getB()); + emitMappedCoopVecMatrixLayoutOperand(outerProduct->getMemoryLayout()); + emitMappedCoopVecComponentTypeOperand(outerProduct->getMatrixInterpretation()); + emitOperand(outerProduct->getMatrixStride()); + }); + } + + SpvInst* emitCoopVecReduceSumAccumulate(SpvInstParent* parent, IRInst* inst) + { + requireSPIRVCapability(SpvCapabilityCooperativeVectorTrainingNV); + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_vector")); + + auto reduceSum = cast(inst); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeVectorReduceSumAccumulateNV, + [&]() + { + emitOperand(emitBufferPtrAsArrayPtr(parent, reduceSum->getBufferPtr())); + emitOperand(reduceSum->getOffset()); + emitOperand(reduceSum->getValue()); + }); + } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as(scalar->getDataType()); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index d7eda765098..212f9ef1f68 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -2046,6 +2046,8 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); } + SLANG_PASS(validateCooperativeOperations, sink); + auto metadata = new ArtifactPostEmitMetadata; outLinkedIR.metadata = metadata; diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index fda7e8b4d69..c6d9ee1ae60 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -778,4 +778,8 @@ return { ["ShapePermute"] = 776, ["ShapeSwap"] = 777, ["ShapeReduce"] = 778, + ["CoopMatMulAdd"] = 779, + ["CoopVecMatMulAdd"] = 780, + ["CoopVecOuterProductAccumulate"] = 781, + ["CoopVecReduceSumAccumulate"] = 782, } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 9d165b157d9..cfaf87e6d11 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -1450,6 +1450,48 @@ local insts = { { TorchGetCudaStream = {} }, { TorchTensorGetView = {} }, { CoopMatMapElementIFunc = { min_operands = 2 } }, + { CoopMatMulAdd = { operands = { { "matA" }, { "matB" }, { "matC" }, { "saturatingAccumulation" } } } }, + { + CoopVecMatMulAdd = { + operands = { + { "input" }, + { "inputInterpretation" }, + { "inputInterpretationPackingFactor" }, + { "matrixPtr" }, + { "matrixOffset" }, + { "matrixInterpretation" }, + { "k" }, + { "memoryLayout" }, + { "transpose" }, + { "matrixStride" }, + { "biasPtr", optional = true }, + { "biasOffset", optional = true }, + { "biasInterpretation", optional = true }, + }, + }, + }, + { + CoopVecOuterProductAccumulate = { + operands = { + { "matrixPtr" }, + { "matrixOffset" }, + { "a" }, + { "b" }, + { "memoryLayout" }, + { "matrixInterpretation" }, + { "matrixStride" }, + }, + }, + }, + { + CoopVecReduceSumAccumulate = { + operands = { + { "bufferPtr" }, + { "offset" }, + { "value" }, + }, + }, + }, { allocateOpaqueHandle = {} }, { BindingQuery = { diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 9f7f818a2df..82937c23692 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -791,4 +791,411 @@ void validateAtomicOperations(IRModule* module, bool skipFuncParamValidation, Di validateAtomicOperations(skipFuncParamValidation, sink, module->getModuleInst()); } +// +// Cooperative matrix/vector validation. +// + +static void validateCoopMatMulAdd(IRInst* inst, DiagnosticSink* sink) +{ + auto emitDiagnostic = [=](const char* message) + { + sink->diagnose(Diagnostics::IrValidationFailed{ + .message = message, + .location = inst->sourceLoc, + }); + }; + + if (inst->getOperandCount() != 4) + { + emitDiagnostic("Malformed CoopMatMulAdd operand list"); + return; + } + + auto coopMatMulAdd = as(inst); + SLANG_ASSERT(coopMatMulAdd); + + auto aType = as(coopMatMulAdd->getMatA()->getDataType()); + auto bType = as(coopMatMulAdd->getMatB()->getDataType()); + auto cType = as(coopMatMulAdd->getMatC()->getDataType()); + auto resultType = as(inst->getDataType()); + if (!aType || !bType || !cType || !resultType) + { + emitDiagnostic( + "CoopMatMulAdd input and result operands must have cooperative matrix types"); + return; + } + + if (!as(coopMatMulAdd->getSaturatingAccumulation())) + { + emitDiagnostic("CoopMatMulAdd saturatingAccumulation operand must be a bool literal"); + return; + } + + auto aUse = as(aType->getMatrixUse()); + auto bUse = as(bType->getMatrixUse()); + auto cUse = as(cType->getMatrixUse()); + auto resultUse = as(resultType->getMatrixUse()); + if (!aUse || !bUse || !cUse || !resultUse) + { + emitDiagnostic("CoopMatMulAdd cooperative matrix uses must be integer literals"); + return; + } + + if (aUse->getValue() != SLANG_COOPERATIVE_MATRIX_USE_A || + bUse->getValue() != SLANG_COOPERATIVE_MATRIX_USE_B || + cUse->getValue() != SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR || + resultUse->getValue() != SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR) + { + emitDiagnostic("CoopMatMulAdd requires MatrixA, MatrixB, and accumulator matrix " + "operands"); + return; + } + + auto aScope = as(aType->getScope()); + auto bScope = as(bType->getScope()); + auto cScope = as(cType->getScope()); + auto resultScope = as(resultType->getScope()); + if (!aScope || !bScope || !cScope || !resultScope) + { + emitDiagnostic("CoopMatMulAdd scopes must be integer literals"); + return; + } + + if (aScope->getValue() != bScope->getValue() || aScope->getValue() != cScope->getValue() || + aScope->getValue() != resultScope->getValue()) + { + emitDiagnostic( + "CoopMatMulAdd requires all cooperative matrix operands to use the same scope"); + return; + } + + auto aRows = as(aType->getRowCount()); + auto aCols = as(aType->getColumnCount()); + auto bRows = as(bType->getRowCount()); + auto bCols = as(bType->getColumnCount()); + auto cRows = as(cType->getRowCount()); + auto cCols = as(cType->getColumnCount()); + auto resultRows = as(resultType->getRowCount()); + auto resultCols = as(resultType->getColumnCount()); + if (!aRows || !aCols || !bRows || !bCols || !cRows || !cCols || !resultRows || !resultCols) + { + emitDiagnostic("CoopMatMulAdd row and column counts must be integer literals"); + return; + } + + if (aRows->getValue() != cRows->getValue() || aRows->getValue() != resultRows->getValue() || + aCols->getValue() != bRows->getValue() || bCols->getValue() != cCols->getValue() || + bCols->getValue() != resultCols->getValue()) + { + emitDiagnostic("CoopMatMulAdd operand dimensions must satisfy A(MxK), B(KxN), and " + "C/result(MxN)"); + return; + } +} + +static bool isValidInputInterpretation( + IRIntegerValue inputInterpretation, + IRIntegerValue packingFactor) +{ + switch ((int32_t)inputInterpretation) + { + case SLANG_SCALAR_TYPE_INT8: + case SLANG_SCALAR_TYPE_UINT8: + return packingFactor == 1 || packingFactor == 4; + + default: + return packingFactor == 1; + } +} + +static bool isValidCoopVecDataOperand(IRInst* operand) +{ + auto type = operand->getDataType(); + if (auto rateQualifiedType = as(type)) + type = rateQualifiedType->getValueType(); + + type = unwrapArray(type); + return as(type) || as(type) || + as(type); +} + +static bool isValidCoopVecAccumulationOperand(IRInst* operand) +{ + auto type = operand->getDataType(); + if (auto rateQualifiedType = as(type)) + type = rateQualifiedType->getValueType(); + + type = unwrapArray(type); + + if (as(type)) + return true; + + switch (type->getOp()) + { + case kIROp_HLSLRWStructuredBufferType: + case kIROp_HLSLRasterizerOrderedStructuredBufferType: + case kIROp_HLSLRWByteAddressBufferType: + case kIROp_HLSLRasterizerOrderedByteAddressBufferType: + return true; + default: + return false; + } +} + +static void validateCoopVecMatMulAdd(IRInst* inst, DiagnosticSink* sink) +{ + auto emitDiagnostic = [=](const char* message) + { + sink->diagnose(Diagnostics::IrValidationFailed{ + .message = message, + .location = inst->sourceLoc, + }); + }; + + if (inst->getOperandCount() != 10 && inst->getOperandCount() != 13) + { + emitDiagnostic("Malformed CoopVecMatMulAdd operand list"); + return; + } + + auto coopVecMatMulAdd = as(inst); + SLANG_ASSERT(coopVecMatMulAdd); + const bool hasBias = inst->getOperandCount() == 13; + + auto resultType = as(inst->getDataType()); + auto inputType = as(coopVecMatMulAdd->getInput()->getDataType()); + if (!resultType || !inputType) + { + emitDiagnostic( + "CoopVecMatMulAdd input and result operands must have cooperative vector types"); + return; + } + + auto inputInterpretation = as(coopVecMatMulAdd->getInputInterpretation()); + if (!inputInterpretation) + { + emitDiagnostic("CoopVecMatMulAdd inputInterpretation operand must be an integer literal"); + return; + } + + if (!as(coopVecMatMulAdd->getMatrixInterpretation())) + { + emitDiagnostic("CoopVecMatMulAdd matrixInterpretation operand must be an integer literal"); + return; + } + + if (!isValidCoopVecDataOperand(coopVecMatMulAdd->getMatrixPtr())) + { + emitDiagnostic("CoopVecMatMulAdd matrix operand must be a pointer, ByteAddressBuffer, or " + "StructuredBuffer"); + return; + } + + if (hasBias) + { + if (!as(coopVecMatMulAdd->getBiasInterpretation())) + { + emitDiagnostic( + "CoopVecMatMulAdd biasInterpretation operand must be an integer literal"); + return; + } + + if (!isValidCoopVecDataOperand(coopVecMatMulAdd->getBiasPtr())) + { + emitDiagnostic("CoopVecMatMulAdd bias operand must be a pointer, ByteAddressBuffer, or " + "StructuredBuffer"); + return; + } + } + + if (!as(coopVecMatMulAdd->getMemoryLayout())) + { + emitDiagnostic("CoopVecMatMulAdd memoryLayout operand must be an integer literal"); + return; + } + + if (!as(coopVecMatMulAdd->getTranspose())) + { + emitDiagnostic("CoopVecMatMulAdd transpose operand must be a bool literal"); + return; + } + + auto k = as(coopVecMatMulAdd->getK()); + if (!k) + { + emitDiagnostic("CoopVecMatMulAdd k operand must be an integer literal"); + return; + } + + auto packingFactor = as(coopVecMatMulAdd->getInputInterpretationPackingFactor()); + if (!packingFactor) + { + emitDiagnostic( + "CoopVecMatMulAdd inputInterpretationPackingFactor operand must be an integer literal"); + return; + } + + auto inputElementCount = as(inputType->getElementCount()); + if (!inputElementCount) + { + emitDiagnostic("CoopVecMatMulAdd input element count must be known at compile time"); + return; + } + + if (!isValidInputInterpretation(inputInterpretation->getValue(), packingFactor->getValue())) + { + emitDiagnostic( + "CoopVecMatMulAdd input interpretation is invalid for the specified packing factor"); + return; + } + + if (packingFactor->getValue() == 1) + { + if (k->getValue() != inputElementCount->getValue()) + { + emitDiagnostic("CoopVecMatMulAdd k operand must match input vector element count for " + "non-packed input interpretations"); + return; + } + } + else + { + if (k->getValue() > packingFactor->getValue() * inputElementCount->getValue()) + { + emitDiagnostic( + "CoopVecMatMulAdd k operand must be less than or equal to the input vector element " + "count times the packing factor for packed input interpretations"); + return; + } + } +} + +static void validateCoopVecOuterProductAccumulate(IRInst* inst, DiagnosticSink* sink) +{ + auto emitDiagnostic = [=](const char* message) + { + sink->diagnose(Diagnostics::IrValidationFailed{ + .message = message, + .location = inst->sourceLoc, + }); + }; + + if (inst->getOperandCount() != 7) + { + emitDiagnostic("Malformed CoopVecOuterProductAccumulate operand list"); + return; + } + + auto outerProduct = as(inst); + SLANG_ASSERT(outerProduct); + + auto aType = as(outerProduct->getA()->getDataType()); + auto bType = as(outerProduct->getB()->getDataType()); + if (!aType || !bType) + { + emitDiagnostic( + "CoopVecOuterProductAccumulate a and b operands must have cooperative vector types"); + return; + } + + if (!as(outerProduct->getMemoryLayout())) + { + emitDiagnostic( + "CoopVecOuterProductAccumulate memoryLayout operand must be an integer literal"); + return; + } + + if (!as(outerProduct->getMatrixInterpretation())) + { + emitDiagnostic("CoopVecOuterProductAccumulate matrixInterpretation operand must be an " + "integer literal"); + return; + } + + if (!isValidCoopVecAccumulationOperand(outerProduct->getMatrixPtr())) + { + emitDiagnostic("CoopVecOuterProductAccumulate matrix operand must be a writable pointer, " + "RWByteAddressBuffer, or RWStructuredBuffer"); + return; + } + + if (inst->getDataType()->getOp() != kIROp_VoidType) + { + emitDiagnostic("CoopVecOuterProductAccumulate result type must be void"); + return; + } +} + +static void validateCoopVecReduceSumAccumulate(IRInst* inst, DiagnosticSink* sink) +{ + auto emitDiagnostic = [=](const char* message) + { + sink->diagnose(Diagnostics::IrValidationFailed{ + .message = message, + .location = inst->sourceLoc, + }); + }; + + if (inst->getOperandCount() != 3) + { + emitDiagnostic("Malformed CoopVecReduceSumAccumulate operand list"); + return; + } + + auto reduceSum = as(inst); + SLANG_ASSERT(reduceSum); + + if (!as(reduceSum->getValue()->getDataType())) + { + emitDiagnostic( + "CoopVecReduceSumAccumulate value operand must have a cooperative vector type"); + return; + } + + if (!isValidCoopVecAccumulationOperand(reduceSum->getBufferPtr())) + { + emitDiagnostic("CoopVecReduceSumAccumulate buffer operand must be a writable pointer, " + "RWByteAddressBuffer, or RWStructuredBuffer"); + return; + } + + if (inst->getDataType()->getOp() != kIROp_VoidType) + { + emitDiagnostic("CoopVecReduceSumAccumulate result type must be void"); + return; + } +} + + +static void validateCooperativeOperations(DiagnosticSink* sink, IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_CoopMatMulAdd: + validateCoopMatMulAdd(inst, sink); + break; + case kIROp_CoopVecMatMulAdd: + validateCoopVecMatMulAdd(inst, sink); + break; + case kIROp_CoopVecOuterProductAccumulate: + validateCoopVecOuterProductAccumulate(inst, sink); + break; + case kIROp_CoopVecReduceSumAccumulate: + validateCoopVecReduceSumAccumulate(inst, sink); + break; + default: + break; + } + + for (auto child : inst->getModifiableChildren()) + { + validateCooperativeOperations(sink, child); + } +} + +void validateCooperativeOperations(IRModule* module, DiagnosticSink* sink) +{ + validateCooperativeOperations(sink, module->getModuleInst()); +} + } // namespace Slang diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index 666dda85c33..e867ec2fb13 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -93,4 +93,7 @@ bool validateStructuredBufferResourceTypes( DiagnosticSink* sink, TargetRequest* targetRequest); +// Validate cooperative matrix/vector operations after type specialization. +void validateCooperativeOperations(IRModule* module, DiagnosticSink* sink); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1f32ccc354e..60441334e16 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -9014,6 +9014,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ConstexprAnd: case kIROp_ConstexprOr: case kIROp_ConstexprSelect: + case kIROp_CoopMatMulAdd: + case kIROp_CoopVecMatMulAdd: case kIROp_Lsh: case kIROp_Rsh: case kIROp_Eql: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 852e71eabee..3ec816a4a1f 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2133,7 +2133,7 @@ struct IRModule : RefObject // anything to do with serialization format // const static UInt k_minSupportedModuleVersion = 4; - const static UInt k_maxSupportedModuleVersion = 12; + const static UInt k_maxSupportedModuleVersion = 13; static_assert(k_minSupportedModuleVersion <= k_maxSupportedModuleVersion); private: diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index c5926483af5..99d1db11685 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -144,6 +144,14 @@ static inline ProgramLayout* convert(SlangReflection* program) return (SlangReflection*)program; } +static bool isScalarType(Type* type) +{ + if (as(type)) + return true; + + return as(type) || as(type) || as(type); +} + // user attribute static unsigned int getUserAttributeCount(Decl* decl) @@ -402,7 +410,7 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) // TODO(tfoley): Don't emit the same type more than once... - if (const auto basicType = as(type)) + if (isScalarType(type)) { return SLANG_TYPE_KIND_SCALAR; } @@ -686,7 +694,7 @@ SLANG_API unsigned int spReflectionType_GetRowCount(SlangReflectionType* inType) { return 1; } - else if (const auto basicType = as(type)) + else if (isScalarType(type)) { return 1; } @@ -708,7 +716,7 @@ SLANG_API unsigned int spReflectionType_GetColumnCount(SlangReflectionType* inTy { return (unsigned int)getIntVal(vectorType->getElementCount()); } - else if (const auto basicType = as(type)) + else if (isScalarType(type)) { return 1; } @@ -752,16 +760,24 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in CASE(Half, FLOAT16); CASE(Float, FLOAT32); CASE(Double, FLOAT64); + CASE(IntPtr, INTPTR); + CASE(UIntPtr, UINTPTR); #undef CASE default: SLANG_REFLECTION_UNEXPECTED(); return SLANG_SCALAR_TYPE_NONE; - break; } } + if (as(type)) + return SLANG_SCALAR_TYPE_BFLOAT16; + if (as(type)) + return SLANG_SCALAR_TYPE_FLOAT_E4M3; + if (as(type)) + return SLANG_SCALAR_TYPE_FLOAT_E5M2; + return SLANG_SCALAR_TYPE_NONE; } diff --git a/source/slang/slang-reflection-json.cpp b/source/slang/slang-reflection-json.cpp index e07f7013803..a4fbdbb6d2b 100644 --- a/source/slang/slang-reflection-json.cpp +++ b/source/slang/slang-reflection-json.cpp @@ -468,6 +468,11 @@ static void emitReflectionScalarTypeInfoJSON(PrettyWriter& writer, SlangScalarTy CASE(Float16, float16); CASE(Float32, float32); CASE(Float64, float64); + CASE(IntPtr, intptr); + CASE(UIntPtr, uintptr); + CASE(BFloat16, bfloat16); + CASE(FloatE4M3, float_e4m3); + CASE(FloatE5M2, float_e5m2); #undef CASE } writer << "\""; diff --git a/tests/cooperative-matrix/mat-mul-add-cuda-codegen.slang b/tests/cooperative-matrix/mat-mul-add-cuda-codegen.slang new file mode 100644 index 00000000000..8e6b2aa6b43 --- /dev/null +++ b/tests/cooperative-matrix/mat-mul-add-cuda-codegen.slang @@ -0,0 +1,18 @@ +//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry computeMain -stage compute + +// CHECK: Slang_CUDA_WMMA::coopMatMulAdd< + +RWStructuredBuffer outputBuffer; + +using namespace linalg; + +[shader("compute")] +[numthreads(32, 1, 1)] +void computeMain() +{ + coopMatMulAdd( + CoopMat(3.0), + CoopMat(5.0), + CoopMat(1.0) + ).Store(outputBuffer, 0, 16); +} diff --git a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang index 92f11e2e205..5afe9fa9293 100644 --- a/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang +++ b/tests/cooperative-matrix/mat-mul-add-spirv-matrix-operands.slang @@ -22,7 +22,8 @@ typealias CoopMatCType = CoopMat( CoopMatAType(2), CoopMatBType(3), diff --git a/tests/cooperative-vector/matrix-mul-hlsl-codegen.slang b/tests/cooperative-vector/matrix-mul-hlsl-codegen.slang new file mode 100644 index 00000000000..368efefa09b --- /dev/null +++ b/tests/cooperative-vector/matrix-mul-hlsl-codegen.slang @@ -0,0 +1,72 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry computeMain -profile cs_6_9 + +// CHECK: ByteAddressBuffer matrix_0 : register +// CHECK: ByteAddressBuffer bias_0 : register +// CHECK: __builtin_MatVecMul(_S{{[0-9]+}}, false, _S{{[0-9]+}}, true, SignedInt8Packed, matrix_0, int(0), SignedInt8, int(4), int(4), RowMajor, false, 4U); +// CHECK: __builtin_MatVecMulAdd(_S{{[0-9]+}}, false, _S{{[0-9]+}}, true, SignedInt8Packed, matrix_0, int(0), SignedInt8, int(4), int(4), RowMajor, false, 4U, bias_0, int(0), SignedInt32); +// CHECK: __builtin_MatVecMul(_S{{[0-9]+}}, false, _S{{[0-9]+}}, false, SignedInt8, matrix_0, int(0), SignedInt8, int(4), int(4), RowMajor, false, 4U); +// CHECK: __builtin_MatVecMulAdd(_S{{[0-9]+}}, false, _S{{[0-9]+}}, false, SignedInt8, matrix_0, int(0), SignedInt8, int(4), int(4), RowMajor, false, 4U, bias_0, int(0), SignedInt32); + +RWStructuredBuffer outputBuffer; +ByteAddressBuffer input; +ByteAddressBuffer matrix; +ByteAddressBuffer bias; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + let packedVec = coopVecLoad<1, uint32_t>(input); + let vec = coopVecLoad<4, int8_t>(input); + + let resultMatMul = coopVecMatMulPacked( + packedVec, + CoopVecComponentType::SignedInt8Packed, + 4, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + let resultMatMulAdd = coopVecMatMulAddPacked( + packedVec, + CoopVecComponentType::SignedInt8Packed, + 4, + matrix, + 0, + CoopVecComponentType::SignedInt8, + bias, + 0, + CoopVecComponentType::SignedInt32, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + let resultMatMulNonPacked = coopVecMatMul( + vec, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + let resultMatMulAddNonPacked = coopVecMatMulAdd( + vec, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + bias, + 0, + CoopVecComponentType::SignedInt32, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + outputBuffer[0] = resultMatMul[0] + resultMatMulAdd[0] + resultMatMulNonPacked[0] + + resultMatMulAddNonPacked[0]; +} diff --git a/tests/cooperative-vector/matrix-mul-spirv-codegen.slang b/tests/cooperative-vector/matrix-mul-spirv-codegen.slang new file mode 100644 index 00000000000..e3243bdfc1d --- /dev/null +++ b/tests/cooperative-vector/matrix-mul-spirv-codegen.slang @@ -0,0 +1,41 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv-asm -entry computeMain -stage compute + +// CHECK: OpCooperativeVectorMatrixMulNV {{.*}} MatrixBSignedComponentsKHR|MatrixResultSignedComponentsKHR +// CHECK: OpCooperativeVectorMatrixMulAddNV {{.*}} MatrixBSignedComponentsKHR|MatrixResultSignedComponentsKHR + +RWStructuredBuffer outputBuffer; +ByteAddressBuffer input; +ByteAddressBuffer matrix; +ByteAddressBuffer bias; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + CoopVec vec = coopVecLoad<4, int8_t>(input); + + let resultMatMul = coopVecMatMul( + vec, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + let resultMatMulAdd = coopVecMatMulAdd( + vec, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + bias, + 0, + CoopVecComponentType::SignedInt32, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + outputBuffer[0] = resultMatMul[0] + resultMatMulAdd[0]; +} diff --git a/tests/cooperative-vector/training-cuda-codegen.slang b/tests/cooperative-vector/training-cuda-codegen.slang new file mode 100644 index 00000000000..e4bd1709af4 --- /dev/null +++ b/tests/cooperative-vector/training-cuda-codegen.slang @@ -0,0 +1,37 @@ +//TEST:SIMPLE(filecheck=CHECK): -target cuda -capability optix_coopvec -entry computeMain -stage compute + +// CHECK: optixCoopVecOuterProductAccumulate({{.*}}, {{.*}}, (CUdeviceptr)(&({{.*}})), int(0), 32U) +// CHECK: optixCoopVecReduceSumAccumulate({{.*}}, (CUdeviceptr)(&({{.*}})), int(0)) + +RWByteAddressBuffer outerProductOutput; +RWByteAddressBuffer reduceSumOutput; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + CoopVec vecA; + CoopVec vecB; + + for (int i = 0; i < vecA.getCount(); ++i) + vecA[i] = half(i + 1); + + for (int i = 0; i < vecB.getCount(); ++i) + vecB[i] = half(i + 1); + + coopVecOuterProductAccumulate( + vecA, + vecB, + outerProductOutput, + 0, + 32, + CoopVecMatrixLayout::TrainingOptimal, + CoopVecComponentType::Float16, + ); + + coopVecReduceSumAccumulate( + vecA, + reduceSumOutput, + 0, + ); +} diff --git a/tests/cooperative-vector/training-hlsl-codegen.slang b/tests/cooperative-vector/training-hlsl-codegen.slang new file mode 100644 index 00000000000..483f89dcbcb --- /dev/null +++ b/tests/cooperative-vector/training-hlsl-codegen.slang @@ -0,0 +1,39 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry computeMain -profile cs_6_9 + +// CHECK: RWByteAddressBuffer outerProductOutput_0 : register +// CHECK: RWByteAddressBuffer reduceSumOutput_0 : register +// CHECK-DAG: __builtin_OuterProductAccumulate(_S{{[0-9]+}}, _S{{[0-9]+}}, outerProductOutput_0, int(0), Float16, TrainingOptimal, 32U); +// CHECK-DAG: __builtin_VectorAccumulate(vecA_0, reduceSumOutput_0, int(0)); + +RWByteAddressBuffer outerProductOutput; +RWByteAddressBuffer reduceSumOutput; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + CoopVec vecA; + CoopVec vecB; + + for (int i = 0; i < vecA.getCount(); ++i) + vecA[i] = half(i + 1); + + for (int i = 0; i < vecB.getCount(); ++i) + vecB[i] = half(i + 1); + + coopVecOuterProductAccumulate( + vecA, + vecB, + outerProductOutput, + 0, + 32, + CoopVecMatrixLayout::TrainingOptimal, + CoopVecComponentType::Float16, + ); + + coopVecReduceSumAccumulate( + vecA, + reduceSumOutput, + 0, + ); +} diff --git a/tests/cooperative-vector/training-spirv-codegen.slang b/tests/cooperative-vector/training-spirv-codegen.slang new file mode 100644 index 00000000000..feb9da0110a --- /dev/null +++ b/tests/cooperative-vector/training-spirv-codegen.slang @@ -0,0 +1,38 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv-asm -entry computeMain -stage compute + +// CHECK: OpCapability CooperativeVectorTrainingNV +// CHECK: OpCooperativeVectorOuterProductAccumulateNV +// CHECK: OpCooperativeVectorReduceSumAccumulateNV + +RWByteAddressBuffer outerProductOutput; +RWByteAddressBuffer reduceSumOutput; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + CoopVec vecA; + CoopVec vecB; + + for (int i = 0; i < vecA.getCount(); ++i) + vecA[i] = half(i + 1); + + for (int i = 0; i < vecB.getCount(); ++i) + vecB[i] = half(i + 1); + + coopVecOuterProductAccumulate( + vecA, + vecB, + outerProductOutput, + 0, + 32, + CoopVecMatrixLayout::TrainingOptimal, + CoopVecComponentType::Float16, + ); + + coopVecReduceSumAccumulate( + vecA, + reduceSumOutput, + 0, + ); +} diff --git a/tests/cuda/optix-coopvec-packed-input-diagnostic.slang b/tests/cuda/optix-coopvec-packed-input-diagnostic.slang new file mode 100644 index 00000000000..5613a2149c2 --- /dev/null +++ b/tests/cuda/optix-coopvec-packed-input-diagnostic.slang @@ -0,0 +1,27 @@ +//DIAGNOSTIC_TEST:SIMPLE(diag=CHECK): -target cuda -capability optix_coopvec -entry computeMain -stage compute + +using namespace linalg; + +ByteAddressBuffer matrix; +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + let packedVec = coopVecLoad<1, uint32_t>(matrix); + let result = coopVecMatMulPacked( + packedVec, + CoopVecComponentType::SignedInt8Packed, + 4, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + outputBuffer[0] = result[0]; +} + +//CHECK: unsupported intrinsic operation +//CHECK: intrinsic operation 'cooperative vector matrix multiply-add with packed input is not implemented yet' is not supported for the current target. diff --git a/tests/cuda/optix-coopvec-transpose-diagnostic.slang b/tests/cuda/optix-coopvec-transpose-diagnostic.slang new file mode 100644 index 00000000000..738eb2fbede --- /dev/null +++ b/tests/cuda/optix-coopvec-transpose-diagnostic.slang @@ -0,0 +1,45 @@ +//DIAGNOSTIC_TEST:SIMPLE(diag=CHECK): -target cuda -capability optix_coopvec -entry computeMain -stage compute + +using namespace linalg; + +ByteAddressBuffer input; +ByteAddressBuffer matrix; +ByteAddressBuffer bias; +RWStructuredBuffer output; + +[numthreads(1,1,1)] +void computeMain() +{ + let vec = coopVecLoad<4, half>(input); + + let resultMatMul = coopVecMatMul( + vec, + CoopVecComponentType::Float16, + matrix, + 0, + CoopVecComponentType::Float16, + CoopVecMatrixLayout::RowMajor, + true, + 8); + + let resultMatMulAdd = coopVecMatMulAdd( + vec, + CoopVecComponentType::Float16, + matrix, + 0, + CoopVecComponentType::Float16, + bias, + 0, + CoopVecComponentType::Float16, + CoopVecMatrixLayout::RowMajor, + true, + 8); + + output[0] = resultMatMul[0]; + output[1] = resultMatMulAdd[0]; +} + +//CHECK: unsupported intrinsic operation +//CHECK: intrinsic operation 'cooperative vector matrix multiply-add with transpose requires Float16 matrix interpretation and InferencingOptimal or TrainingOptimal matrix layout for OptiX' is not supported for the current target. +//CHECK: unsupported intrinsic operation +//CHECK: intrinsic operation 'cooperative vector matrix multiply-add with transpose requires Float16 matrix interpretation and InferencingOptimal or TrainingOptimal matrix layout for OptiX' is not supported for the current target. diff --git a/tests/cuda/optix-coopvec.slang b/tests/cuda/optix-coopvec.slang index 120620295f0..4896ae8cfd4 100644 --- a/tests/cuda/optix-coopvec.slang +++ b/tests/cuda/optix-coopvec.slang @@ -21,8 +21,8 @@ // CHECK: (optixCoopVecMax((_S{{[0-9]+}}), (_S{{[0-9]+}}))) // CHECK: (optixCoopVecMin((_S{{[0-9]+}}), (_S{{[0-9]+}}))) // CHECK: (optixCoopVecMul((_S{{[0-9]+}}), (_S{{[0-9]+}}))) -// CHECK: optixCoopVecOuterProductAccumulate((_S{{[0-9]+}}), (_S{{[0-9]+}}), (CUdeviceptr)(&(globalParams_{{[0-9]+}}->outputMat_{{[0-9]+}})), (int(0)), (32U)) -// CHECK: optixCoopVecReduceSumAccumulate((_S{{[0-9]+}}), (CUdeviceptr)(&(globalParams_{{[0-9]+}}->outputMat{{[0-9]+}}_{{[0-9]+}})), (int(0))) +// CHECK: optixCoopVecOuterProductAccumulate(_S{{[0-9]+}}, _S{{[0-9]+}}, (CUdeviceptr)(&(globalParams_{{[0-9]+}}->outputMat_{{[0-9]+}})), int(0), 32U) +// CHECK: optixCoopVecReduceSumAccumulate(_S{{[0-9]+}}, (CUdeviceptr)(&(globalParams_{{[0-9]+}}->outputMat{{[0-9]+}}_{{[0-9]+}})), int(0)) // CHECK: (optixCoopVecStep((_S{{[0-9]+}}), (_S{{[0-9]+}}))) // CHECK: (optixCoopVecSub((_S{{[0-9]+}}), (_S{{[0-9]+}}))) // CHECK: (optixCoopVecLog2((_S{{[0-9]+}}))) @@ -171,21 +171,23 @@ void closestHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAt CoopVec resultMin = min(vec1, vec2); CoopVec resultVecMul = vec1 * vec2; + CoopVec trainingVec1 = coopVecLoad<4, half>(input1); + CoopVec trainingVec2 = coopVecLoad<4, half>(input2); outputMat.Store(0, float(1)); coopVecOuterProductAccumulate( - vec1, - vec2, + trainingVec1, + trainingVec2, outputMat, 0, 32, - CoopVecMatrixLayout::RowMajor, - CoopVecComponentType::Float32, + CoopVecMatrixLayout::TrainingOptimal, + CoopVecComponentType::Float16, ); - outputMat2.Store(0, float(1)); + outputMat2.Store(0, float(1)); coopVecReduceSumAccumulate( - vec1, + trainingVec1, outputMat2, 0, ); diff --git a/tools/gfx/slang.slang b/tools/gfx/slang.slang index 75fa8a01fdd..3b0d56bb06d 100644 --- a/tools/gfx/slang.slang +++ b/tools/gfx/slang.slang @@ -277,6 +277,11 @@ public enum SlangScalarType UINT8, INT16, UINT16, + INTPTR, + UINTPTR, + BFLOAT16, + FLOAT_E4M3, + FLOAT_E5M2, }; public struct TypeReflection diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 98422066b24..f47d2a06b8c 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -3,6 +3,7 @@ #include "shader-input-layout.h" +#include "core/slang-math.h" #include "core/slang-token-reader.h" #include "core/slang-type-text-util.h" @@ -1147,6 +1148,39 @@ void ShaderInputLayout::parse(RandomGenerator* rand, const char* source) } break; } + case ScalarType::BFloat16: + { + auto ptr = (const uint16_t*)data; + const size_t size = sizeInBytes / sizeof(ptr[0]); + for (size_t i = 0; i < size; ++i) + { + const float v = BFloat16ToFloat(ptr[i]); + writer.print("%f\n", v); + } + break; + } + case ScalarType::FloatE4M3: + { + auto ptr = (const uint8_t*)data; + const size_t size = sizeInBytes / sizeof(ptr[0]); + for (size_t i = 0; i < size; ++i) + { + const float v = FloatE4M3ToFloat(ptr[i]); + writer.print("%f\n", v); + } + break; + } + case ScalarType::FloatE5M2: + { + auto ptr = (const uint8_t*)data; + const size_t size = sizeInBytes / sizeof(ptr[0]); + for (size_t i = 0; i < size; ++i) + { + const float v = FloatE5M2ToFloat(ptr[i]); + writer.print("%f\n", v); + } + break; + } #define CASE(SLANG_TYPE, C_TYPE, FORMAT) \ case ScalarType::SLANG_TYPE: \ { \ @@ -1167,6 +1201,8 @@ void ShaderInputLayout::parse(RandomGenerator* rand, const char* source) CASE(Int32, int32_t, PRId32); CASE(UInt64, uint64_t, PRIu64); CASE(Int64, int64_t, PRId64); + CASE(UIntPtr, uintptr_t, PRIuPTR); + CASE(IntPtr, intptr_t, PRIdPTR); CASE(Float32, float, "f"); CASE(Float64, double, "f"); #undef CASE @@ -1452,6 +1488,9 @@ void generateTextureDataRGB8(TextureData& output, const InputTextureDesc& inputD case SLANG_SCALAR_TYPE_FLOAT64: case SLANG_SCALAR_TYPE_FLOAT32: case SLANG_SCALAR_TYPE_FLOAT16: + case SLANG_SCALAR_TYPE_BFLOAT16: + case SLANG_SCALAR_TYPE_FLOAT_E4M3: + case SLANG_SCALAR_TYPE_FLOAT_E5M2: type = SimpleScalarType::kFloat; break; default: diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index adcc3cb8fc9..ea23cf04fa4 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -4050,6 +4050,9 @@ static SlangResult _compareWithType( case ScalarType::Float16: case ScalarType::Float32: case ScalarType::Float64: + case ScalarType::BFloat16: + case ScalarType::FloatE4M3: + case ScalarType::FloatE5M2: { // Compare as double diff --git a/tools/slang-unit-test/unit-test-special-scalar-reflection.cpp b/tools/slang-unit-test/unit-test-special-scalar-reflection.cpp new file mode 100644 index 00000000000..fe3f12d02a8 --- /dev/null +++ b/tools/slang-unit-test/unit-test-special-scalar-reflection.cpp @@ -0,0 +1,144 @@ +// unit-test-special-scalar-reflection.cpp + +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +using namespace Slang; + +SLANG_UNIT_TEST(specialScalarReflection) +{ + const char* userSourceBody = R"( + struct TestStruct + { + BFloat16 bf; + FloatE4M3 e4; + FloatE5M2 e5; + intptr_t ip; + uintptr_t up; + vector vbf; + vector ve4; + vector ve5; + vector vip; + vector vup; + }; + + StructuredBuffer gData; + )"; + + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_CUDA_SOURCE; + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto reflection = module->getLayout(); + SLANG_CHECK(reflection != nullptr); + + auto gDataLayout = reflection->getParameterByIndex(0); + SLANG_CHECK(gDataLayout != nullptr); + + auto gDataType = gDataLayout->getType(); + SLANG_CHECK(gDataType != nullptr); + SLANG_CHECK(gDataType->getKind() == slang::TypeReflection::Kind::Resource); + + auto resultType = gDataType->getResourceResultType(); + SLANG_CHECK(resultType != nullptr); + SLANG_CHECK(resultType->getKind() == slang::TypeReflection::Kind::Struct); + SLANG_CHECK_ABORT(resultType->getFieldCount() == 10); + + auto bfField = resultType->getFieldByIndex(0); + auto e4Field = resultType->getFieldByIndex(1); + auto e5Field = resultType->getFieldByIndex(2); + auto ipField = resultType->getFieldByIndex(3); + auto upField = resultType->getFieldByIndex(4); + auto vbfField = resultType->getFieldByIndex(5); + auto ve4Field = resultType->getFieldByIndex(6); + auto ve5Field = resultType->getFieldByIndex(7); + auto vipField = resultType->getFieldByIndex(8); + auto vupField = resultType->getFieldByIndex(9); + + SLANG_CHECK(bfField != nullptr); + SLANG_CHECK(e4Field != nullptr); + SLANG_CHECK(e5Field != nullptr); + SLANG_CHECK(ipField != nullptr); + SLANG_CHECK(upField != nullptr); + SLANG_CHECK(vbfField != nullptr); + SLANG_CHECK(ve4Field != nullptr); + SLANG_CHECK(ve5Field != nullptr); + SLANG_CHECK(vipField != nullptr); + SLANG_CHECK(vupField != nullptr); + + auto bfType = bfField->getType(); + auto e4Type = e4Field->getType(); + auto e5Type = e5Field->getType(); + auto ipType = ipField->getType(); + auto upType = upField->getType(); + auto vbfType = vbfField->getType(); + auto ve4Type = ve4Field->getType(); + auto ve5Type = ve5Field->getType(); + auto vipType = vipField->getType(); + auto vupType = vupField->getType(); + + SLANG_CHECK(bfType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(e4Type->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(e5Type->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(ipType->getKind() == slang::TypeReflection::Kind::Scalar); + SLANG_CHECK(upType->getKind() == slang::TypeReflection::Kind::Scalar); + + SLANG_CHECK(bfType->getScalarType() == slang::TypeReflection::ScalarType::BFloat16); + SLANG_CHECK(e4Type->getScalarType() == slang::TypeReflection::ScalarType::FloatE4M3); + SLANG_CHECK(e5Type->getScalarType() == slang::TypeReflection::ScalarType::FloatE5M2); + SLANG_CHECK(ipType->getScalarType() == slang::TypeReflection::ScalarType::IntPtr); + SLANG_CHECK(upType->getScalarType() == slang::TypeReflection::ScalarType::UIntPtr); + + SLANG_CHECK(bfType->getRowCount() == 1); + SLANG_CHECK(bfType->getColumnCount() == 1); + SLANG_CHECK(e4Type->getRowCount() == 1); + SLANG_CHECK(e4Type->getColumnCount() == 1); + SLANG_CHECK(e5Type->getRowCount() == 1); + SLANG_CHECK(e5Type->getColumnCount() == 1); + SLANG_CHECK(ipType->getRowCount() == 1); + SLANG_CHECK(ipType->getColumnCount() == 1); + SLANG_CHECK(upType->getRowCount() == 1); + SLANG_CHECK(upType->getColumnCount() == 1); + + SLANG_CHECK(vbfType->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(ve4Type->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(ve5Type->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(vipType->getKind() == slang::TypeReflection::Kind::Vector); + SLANG_CHECK(vupType->getKind() == slang::TypeReflection::Kind::Vector); + + auto vbfElementType = vbfType->getElementType(); + auto ve4ElementType = ve4Type->getElementType(); + auto ve5ElementType = ve5Type->getElementType(); + auto vipElementType = vipType->getElementType(); + auto vupElementType = vupType->getElementType(); + + SLANG_CHECK(vbfElementType != nullptr); + SLANG_CHECK(ve4ElementType != nullptr); + SLANG_CHECK(ve5ElementType != nullptr); + SLANG_CHECK(vipElementType != nullptr); + SLANG_CHECK(vupElementType != nullptr); + + SLANG_CHECK(vbfElementType->getScalarType() == slang::TypeReflection::ScalarType::BFloat16); + SLANG_CHECK(ve4ElementType->getScalarType() == slang::TypeReflection::ScalarType::FloatE4M3); + SLANG_CHECK(ve5ElementType->getScalarType() == slang::TypeReflection::ScalarType::FloatE5M2); + SLANG_CHECK(vipElementType->getScalarType() == slang::TypeReflection::ScalarType::IntPtr); + SLANG_CHECK(vupElementType->getScalarType() == slang::TypeReflection::ScalarType::UIntPtr); +}