Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions prelude/slang-cuda-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -4955,3 +4955,184 @@ _slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int cluster
}

#undef SLANG_WAVE_CLUSTERED_ROTATE_IMPL


// ---------------------- OptiX Cooperative Vector Wrappers --------------------------------------
#ifdef SLANG_CUDA_ENABLE_OPTIX

// Constexpr function to map Slang component type enum to OptiX cooperative vector element type
__host__ __device__ constexpr OptixCoopVecElemType slangToOptixComponentType(unsigned slangEnum)
{
switch (slangEnum)
{
case 0:
return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; // FloatE4M3
case 1:
return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; // FloatE5M2
case 2:
return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; // Float16
case 3:
return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Float32
case 5:
return OPTIX_COOP_VEC_ELEM_TYPE_INT8; // SignedInt8
case 7:
return OPTIX_COOP_VEC_ELEM_TYPE_INT32; // SignedInt32
case 10:
return OPTIX_COOP_VEC_ELEM_TYPE_UINT8; // UnsignedInt8
case 12:
return OPTIX_COOP_VEC_ELEM_TYPE_UINT32; // UnsignedInt32
default:
return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Default
}
}

// Constexpr function to map Slang matrix layout enum to OptiX cooperative vector matrix layout
__host__ __device__ constexpr OptixCoopVecMatrixLayout slangToOptixMatrixLayout(unsigned slangEnum)
{
switch (slangEnum)
{
case 0:
return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // RowMajor
case 1:
return OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; // ColumnMajor
case 2:
return OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; // InferencingOptimal
case 3:
return OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; // TrainingOptimal
default:
return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR; // Default
}
}

// Wrapper structs to maintain compatibility with existing template-based interface
template<unsigned SlangEnum>
struct SlangToOptixComponentType
{
static constexpr OptixCoopVecElemType value = slangToOptixComponentType(SlangEnum);
};

template<unsigned SlangEnum>
struct SlangToOptixMatrixLayout
{
static constexpr OptixCoopVecMatrixLayout value = slangToOptixMatrixLayout(SlangEnum);
};

// Template trait to extract vector size from OptixCoopVec<T, N>
// Conditional compilation for NVRTC compatibility
template<typename T>
struct OptixCoopVecTraits;

// Template specialization for OptiX's OptixCoopVec - only enabled when cooperative vectors are
// available NVRTC explicitly disables cooperative vectors by setting
// OPTIX_INCLUDE_COOPERATIVE_VECTOR to 0
#if defined(OPTIX_VERSION) && OPTIX_VERSION > 90000
template<typename T, unsigned int N>
struct OptixCoopVecTraits<OptixCoopVec<T, N>>
{
static constexpr unsigned int size = N;
};
#endif

template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
unsigned matrixOffset,
bool transpose,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITH bias - 6 runtime params)
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout,
unsigned biasInterpretation>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
unsigned matrixOffset,
CUdeviceptr bias,
unsigned biasOffset,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

// Call OptiX SDK with bias (6 runtime parameters)
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value,
SlangToOptixComponentType<biasInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
bias,
biasOffset,
matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params -
// StructuredBuffer variant)
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
unsigned matrixOffset,
unsigned matrixStride)
{
constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size; // Output vector size
constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size; // Input vector size

// Call OptiX SDK without bias and without transpose (4 runtime parameters)
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
}

#endif // SLANG_CUDA_ENABLE_OPTIX
Loading