Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 140 additions & 50 deletions prelude/slang-cuda-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -4960,61 +4960,130 @@ _slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int cluster
#ifdef SLANG_CUDA_ENABLE_OPTIX

template<unsigned SlangEnum>
struct SlangToOptixComponentType {
struct SlangToOptixComponentType
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Default
};

template<> struct SlangToOptixComponentType<0> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; }; // FloatE4M3
template<> struct SlangToOptixComponentType<1> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; }; // FloatE5M2
template<> struct SlangToOptixComponentType<2> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; }; // Float16
template<> struct SlangToOptixComponentType<3> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; }; // Float32
template<> struct SlangToOptixComponentType<5> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT8; }; // SignedInt8
template<> struct SlangToOptixComponentType<7> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT32; }; // SignedInt32
template<> struct SlangToOptixComponentType<10> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT8; }; // UnsignedInt8
template<> struct SlangToOptixComponentType<12> { static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT32; }; // UnsignedInt32
template<>
struct SlangToOptixComponentType<0>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3;
}; // FloatE4M3
template<>
struct SlangToOptixComponentType<1>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2;
}; // FloatE5M2
template<>
struct SlangToOptixComponentType<2>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16;
}; // Float16
template<>
struct SlangToOptixComponentType<3>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32;
}; // Float32
template<>
struct SlangToOptixComponentType<5>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT8;
}; // SignedInt8
template<>
struct SlangToOptixComponentType<7>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_INT32;
}; // SignedInt32
template<>
struct SlangToOptixComponentType<10>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT8;
}; // UnsignedInt8
template<>
struct SlangToOptixComponentType<12>
{
static constexpr OptixCoopVecElemType value = OPTIX_COOP_VEC_ELEM_TYPE_UINT32;
}; // UnsignedInt32

template<unsigned SlangEnum>
struct SlangToOptixMatrixLayout {
static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default
struct SlangToOptixMatrixLayout
{
static constexpr OptixCoopVecMatrixLayout value =
OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default
};

template<> struct SlangToOptixMatrixLayout<0> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; }; // RowMajor
template<> struct SlangToOptixMatrixLayout<1> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; }; // ColumnMajor
template<> struct SlangToOptixMatrixLayout<2> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; }; // InferencingOptimal
template<> struct SlangToOptixMatrixLayout<3> { static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; }; // TrainingOptimal
template<>
struct SlangToOptixMatrixLayout<0>
{
static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR;
}; // RowMajor
template<>
struct SlangToOptixMatrixLayout<1>
{
static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR;
}; // ColumnMajor
template<>
struct SlangToOptixMatrixLayout<2>
{
static constexpr OptixCoopVecMatrixLayout value =
OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL;
}; // InferencingOptimal
template<>
struct SlangToOptixMatrixLayout<3>
{
static constexpr OptixCoopVecMatrixLayout value = OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL;
}; // TrainingOptimal

// Template trait to extract vector size from OptixCoopVec<T, N>
template<typename T>
struct OptixCoopVecTraits;

template<typename T, unsigned N>
struct OptixCoopVecTraits<OptixCoopVec<T, N>> {
struct OptixCoopVecTraits<OptixCoopVec<T, N>>
{
static constexpr unsigned size = N;
};

template<typename VecTOut, typename VecTIn,
unsigned inputInterpretation, unsigned matrixInterpretation, unsigned matrixLayout>
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
unsigned matrixOffset,
bool transpose,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

return optixCoopVecMatMul<VecTOut, VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false, N, K,
SlangToOptixComponentType<matrixInterpretation>::value>
(inputVector, matrix, matrixOffset, matrixStride);
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITH bias - 6 runtime params)
template<typename VecTOut, typename VecTIn,
unsigned inputInterpretation, unsigned matrixInterpretation, unsigned matrixLayout, unsigned biasInterpretation>
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout,
unsigned biasInterpretation>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
Expand All @@ -5023,38 +5092,59 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
unsigned biasOffset,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

// Call OptiX SDK with bias (6 runtime parameters)
return optixCoopVecMatMul<VecTOut, VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false, N, K,
SlangToOptixComponentType<matrixInterpretation>::value,
SlangToOptixComponentType<biasInterpretation>::value>
(inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params - StructuredBuffer variant)
template<typename VecTOut, typename VecTIn,
unsigned inputInterpretation, unsigned matrixInterpretation, unsigned matrixLayout>
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value,
SlangToOptixComponentType<biasInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
bias,
biasOffset,
matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params -
// StructuredBuffer variant)
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
unsigned matrixOffset,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

// Call OptiX SDK without bias and without transpose (4 runtime parameters)
return optixCoopVecMatMul<VecTOut, VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false, N, K,
SlangToOptixComponentType<matrixInterpretation>::value>
(inputVector, matrix, matrixOffset, matrixStride);
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
}

#endif // SLANG_CUDA_ENABLE_OPTIX