diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 38bdbace698..b79f04ebd7e 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -4960,40 +4960,97 @@ _slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int cluster #ifdef SLANG_CUDA_ENABLE_OPTIX template -struct SlangToOptixComponentType { +struct SlangToOptixComponentType +{ static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Default }; -template<> struct SlangToOptixComponentType<0> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; }; // FloatE4M3 -template<> struct SlangToOptixComponentType<1> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; }; // FloatE5M2 -template<> struct SlangToOptixComponentType<2> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; }; // Float16 -template<> struct SlangToOptixComponentType<3> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; }; // Float32 -template<> struct SlangToOptixComponentType<5> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT8; }; // SignedInt8 -template<> struct SlangToOptixComponentType<7> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT32; }; // SignedInt32 -template<> struct SlangToOptixComponentType<10> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT8; }; // UnsignedInt8 -template<> struct SlangToOptixComponentType<12> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT32; }; // UnsignedInt32 +template<> +struct SlangToOptixComponentType<0> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; +}; // FloatE4M3 +template<> +struct SlangToOptixComponentType<1> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; +}; // FloatE5M2 +template<> +struct SlangToOptixComponentType<2> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; +}; // Float16 +template<> +struct SlangToOptixComponentType<3> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; +}; // Float32 +template<> +struct SlangToOptixComponentType<5> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT8; +}; // SignedInt8 +template<> +struct SlangToOptixComponentType<7> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT32; +}; // SignedInt32 +template<> +struct SlangToOptixComponentType<10> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT8; +}; // UnsignedInt8 +template<> +struct SlangToOptixComponentType<12> +{ + static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT32; +}; // UnsignedInt32 template -struct SlangToOptixMatrixLayout { - static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default +struct SlangToOptixMatrixLayout +{ + static constexpr OptixCoopVecMatrixLayout value = + OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default }; -template<> struct SlangToOptixMatrixLayout<0> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; }; // RowMajor -template<> struct SlangToOptixMatrixLayout<1> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; }; // ColumnMajor -template<> struct SlangToOptixMatrixLayout<2> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; }; // InferencingOptimal -template<> struct SlangToOptixMatrixLayout<3> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; }; // TrainingOptimal +template<> +struct SlangToOptixMatrixLayout<0> +{ + static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; +}; // RowMajor +template<> +struct SlangToOptixMatrixLayout<1> +{ + static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; +}; // ColumnMajor +template<> +struct SlangToOptixMatrixLayout<2> +{ + static constexpr OptixCoopVecMatrixLayout value = + OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; +}; // InferencingOptimal +template<> +struct SlangToOptixMatrixLayout<3> +{ + static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; +}; // TrainingOptimal // Template trait to extract vector size from OptixCoopVec template struct OptixCoopVecTraits; template -struct OptixCoopVecTraits> { +struct OptixCoopVecTraits> +{ static constexpr unsigned size = N; }; -template +template< + typename VecTOut, + typename VecTIn, + unsigned inputInterpretation, + unsigned matrixInterpretation, + unsigned matrixLayout> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, @@ -5001,20 +5058,32 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( bool transpose, unsigned matrixStride) { - constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size - constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size + constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size + constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size - return optixCoopVecMatMul::value, - SlangToOptixMatrixLayout::value, - false, N, K, - SlangToOptixComponentType::value> - (inputVector, matrix, matrixOffset, matrixStride); + return optixCoopVecMatMul< + VecTOut, + VecTIn, + SlangToOptixComponentType::value, + SlangToOptixMatrixLayout::value, + false, + N, + K, + SlangToOptixComponentType::value>( + inputVector, + matrix, + matrixOffset, + matrixStride); } // OptiX cooperative vector matrix multiplication wrapper (WITH bias - 6 runtime params) -template +template< + typename VecTOut, + typename VecTIn, + unsigned inputInterpretation, + unsigned matrixInterpretation, + unsigned matrixLayout, + unsigned biasInterpretation> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, @@ -5023,38 +5092,59 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( unsigned biasOffset, unsigned matrixStride) { - constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size - constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size + constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size + constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size // Call OptiX SDK with bias (6 runtime parameters) - return optixCoopVecMatMul::value, - SlangToOptixMatrixLayout::value, - false, N, K, - SlangToOptixComponentType::value, - SlangToOptixComponentType::value> - (inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride); -} - -// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params - StructuredBuffer variant) -template + return optixCoopVecMatMul< + VecTOut, + VecTIn, + SlangToOptixComponentType::value, + SlangToOptixMatrixLayout::value, + false, + N, + K, + SlangToOptixComponentType::value, + SlangToOptixComponentType::value>( + inputVector, + matrix, + matrixOffset, + bias, + biasOffset, + matrixStride); +} + +// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params - +// StructuredBuffer variant) +template< + typename VecTOut, + typename VecTIn, + unsigned inputInterpretation, + unsigned matrixInterpretation, + unsigned matrixLayout> __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul( const VecTIn& inputVector, CUdeviceptr matrix, unsigned matrixOffset, unsigned matrixStride) { - constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size - constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size + constexpr unsigned N = OptixCoopVecTraits::size; // Output vector size + constexpr unsigned K = OptixCoopVecTraits::size; // Input vector size // Call OptiX SDK without bias and without transpose (4 runtime parameters) - return optixCoopVecMatMul::value, - SlangToOptixMatrixLayout::value, - false, N, K, - SlangToOptixComponentType::value> - (inputVector, matrix, matrixOffset, matrixStride); + return optixCoopVecMatMul< + VecTOut, + VecTIn, + SlangToOptixComponentType::value, + SlangToOptixMatrixLayout::value, + false, + N, + K, + SlangToOptixComponentType::value>( + inputVector, + matrix, + matrixOffset, + matrixStride); } #endif // SLANG_CUDA_ENABLE_OPTIX