Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,23 @@ typedef uint32_t SlangSizeT;
SLANG_STAGE_PIXEL = SLANG_STAGE_FRAGMENT,
};
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Gap: New public API enums lack documentation comments

The neighboring enums in slang.h (e.g., SlangDebugInfoLevel, SlangOptimizationLevel, SlangStage) have documentation comments describing their purpose. These two new enums — SlangCooperativeMatrixUse and SlangCooperativeVectorMatrixLayout — are added without any doc comments. Since they're part of the public API consumed by downstream tools and language bindings, documentation would help users understand their role without needing to read the SPIR-V or HLSL specs.

Suggestion:

/// Specifies the role of a cooperative matrix in a multiply-accumulate operation.
typedef SlangUInt32 SlangCooperativeMatrixUseIntegral;
enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral
{
    SLANG_COOPERATIVE_MATRIX_USE_A,           ///< Left-hand matrix (A in D = A*B + C)
    SLANG_COOPERATIVE_MATRIX_USE_B,           ///< Right-hand matrix (B in D = A*B + C)
    SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR, ///< Accumulator matrix (C or D in D = A*B + C)
};

/// Memory layout for cooperative vector matrix data in device memory.
typedef SlangUInt32 SlangCooperativeVectorMatrixLayoutIntegral;
enum SlangCooperativeVectorMatrixLayout : SlangCooperativeVectorMatrixLayoutIntegral
{
    SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR,           ///< Standard row-major layout
    SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR,        ///< Standard column-major layout
    SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL, ///< Layout optimized for inference
    SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL,    ///< Layout optimized for training
};


/// Specifies the role a cooperative matrix plays in a multiply-accumulate
/// operation (the operand names suggest the A, B, or accumulator slot of
/// D = A*B + C, following the SPIR-V CooperativeMatrixUse convention —
/// confirm against the SPIR-V cooperative-matrix spec).
typedef SlangUInt32 SlangCooperativeMatrixUseIntegral;
enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral
{
SLANG_COOPERATIVE_MATRIX_USE_A,           ///< Left-hand matrix operand (A)
SLANG_COOPERATIVE_MATRIX_USE_B,           ///< Right-hand matrix operand (B)
SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR, ///< Accumulator operand (C/D)
};

/// Memory layout used for cooperative-vector matrix data in device memory.
typedef SlangUInt32 SlangCooperativeVectorMatrixLayoutIntegral;
enum SlangCooperativeVectorMatrixLayout : SlangCooperativeVectorMatrixLayoutIntegral
{
SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_ROW_MAJOR,           ///< Standard row-major layout
SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_COLUMN_MAJOR,        ///< Standard column-major layout
SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_INFERENCING_OPTIMAL, ///< Implementation-chosen layout optimized for inferencing
SLANG_COOPERATIVE_VECTOR_MATRIX_LAYOUT_TRAINING_OPTIMAL,    ///< Implementation-chosen layout optimized for training
};
Comment thread
jkwak-work marked this conversation as resolved.

typedef SlangUInt32 SlangDebugInfoLevelIntegral;
enum SlangDebugInfoLevel : SlangDebugInfoLevelIntegral
{
Comment on lines 858 to 879
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Gap: New public API enums and scalar types lack inline documentation

SlangCooperativeMatrixUse and SlangCooperativeVectorMatrixLayout (here) plus the new scalar types SLANG_SCALAR_TYPE_BFLOAT16, SLANG_SCALAR_TYPE_FLOAT_E4M3, SLANG_SCALAR_TYPE_FLOAT_E5M2 (line 1988) are user-facing API additions in a public header consumed by downstream tools and bindings. Other enums in this header (e.g., SlangDebugInfoLevel) include documentation comments.

Suggestion: Add brief Doxygen-style comments:

/** Specifies which role a cooperative matrix plays in multiply-accumulate. */
typedef SlangUInt32 SlangCooperativeMatrixUseIntegral;
enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral
{
    SLANG_COOPERATIVE_MATRIX_USE_A,            ///< Left-hand matrix (A) in D = A*B+C
    SLANG_COOPERATIVE_MATRIX_USE_B,            ///< Right-hand matrix (B) in D = A*B+C
    SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR,  ///< Accumulator matrix (C/D) in D = A*B+C
};

And for the scalar types:

SLANG_SCALAR_TYPE_BFLOAT16,    ///< Brain Float 16: 1 sign + 8 exp + 7 mantissa
SLANG_SCALAR_TYPE_FLOAT_E4M3,  ///< FP8 E4M3: 1 sign + 4 exp + 3 mantissa
SLANG_SCALAR_TYPE_FLOAT_E5M2,  ///< FP8 E5M2: 1 sign + 5 exp + 2 mantissa

Expand Down Expand Up @@ -1964,7 +1981,10 @@ public: \
SLANG_SCALAR_TYPE_INT16,
SLANG_SCALAR_TYPE_UINT16,
SLANG_SCALAR_TYPE_INTPTR,
SLANG_SCALAR_TYPE_UINTPTR
SLANG_SCALAR_TYPE_UINTPTR,
SLANG_SCALAR_TYPE_BFLOAT16,
SLANG_SCALAR_TYPE_FLOAT_E4M3,
SLANG_SCALAR_TYPE_FLOAT_E5M2,
};

// abstract decl reflection
Expand Down Expand Up @@ -2349,6 +2369,11 @@ struct TypeReflection
UInt8 = SLANG_SCALAR_TYPE_UINT8,
Int16 = SLANG_SCALAR_TYPE_INT16,
UInt16 = SLANG_SCALAR_TYPE_UINT16,
IntPtr = SLANG_SCALAR_TYPE_INTPTR,
UIntPtr = SLANG_SCALAR_TYPE_UINTPTR,
BFloat16 = SLANG_SCALAR_TYPE_BFLOAT16,
FloatE4M3 = SLANG_SCALAR_TYPE_FLOAT_E4M3,
FloatE5M2 = SLANG_SCALAR_TYPE_FLOAT_E5M2,
};
Comment thread
coderabbitai[bot] marked this conversation as resolved.

Kind getKind() { return (Kind)spReflectionType_GetKind((SlangReflectionType*)this); }
Expand Down
111 changes: 20 additions & 91 deletions prelude/slang-cuda-prelude.h
Original file line number Diff line number Diff line change
Expand Up @@ -6461,63 +6461,6 @@ _slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int cluster

#if (OPTIX_VERSION >= 90000)

// Maps a Slang component-type enum value onto the corresponding OptiX
// cooperative-vector element type. Any value without an explicit mapping
// degrades to FLOAT32.
__host__ __device__ constexpr OptixCoopVecElemType slangToOptixComponentType(unsigned slangEnum)
{
    if (slangEnum == 0)
        return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E4M3; // FloatE4M3
    if (slangEnum == 1)
        return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT8_E5M2; // FloatE5M2
    if (slangEnum == 2)
        return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT16; // Float16
    if (slangEnum == 3)
        return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Float32
    if (slangEnum == 5)
        return OPTIX_COOP_VEC_ELEM_TYPE_INT8; // SignedInt8
    if (slangEnum == 7)
        return OPTIX_COOP_VEC_ELEM_TYPE_INT32; // SignedInt32
    if (slangEnum == 10)
        return OPTIX_COOP_VEC_ELEM_TYPE_UINT8; // UnsignedInt8
    if (slangEnum == 12)
        return OPTIX_COOP_VEC_ELEM_TYPE_UINT32; // UnsignedInt32
    return OPTIX_COOP_VEC_ELEM_TYPE_FLOAT32; // Default
}

// Maps a Slang matrix-layout enum value onto the corresponding OptiX
// cooperative-vector matrix layout. Unrecognized values fall back to
// row-major, mirroring the component-type mapper's default behavior.
__host__ __device__ constexpr OptixCoopVecMatrixLayout slangToOptixMatrixLayout(unsigned slangEnum)
{
    switch (slangEnum)
    {
    case 1:
        return OPTIX_COOP_VEC_MATRIX_LAYOUT_COLUMN_MAJOR; // ColumnMajor
    case 2:
        return OPTIX_COOP_VEC_MATRIX_LAYOUT_INFERENCING_OPTIMAL; // InferencingOptimal
    case 3:
        return OPTIX_COOP_VEC_MATRIX_LAYOUT_TRAINING_OPTIMAL; // TrainingOptimal
    case 0: // RowMajor
    default: // Fallback
        return OPTIX_COOP_VEC_MATRIX_LAYOUT_ROW_MAJOR;
    }
}

// Wrapper structs preserving the existing template-based interface: each one
// surfaces the matching constexpr mapping function as a `::value` member so
// template call sites written against the old traits keep compiling.
template<unsigned TSlangEnum>
struct SlangToOptixComponentType
{
    static constexpr OptixCoopVecElemType value = slangToOptixComponentType(TSlangEnum);
};

template<unsigned TSlangEnum>
struct SlangToOptixMatrixLayout
{
    static constexpr OptixCoopVecMatrixLayout value = slangToOptixMatrixLayout(TSlangEnum);
};

// Template trait to extract vector size from OptixCoopVec<T, N>
// Conditional compilation for NVRTC compatibility
template<typename T>
Expand All @@ -6537,9 +6480,9 @@ struct OptixCoopVecTraits<OptixCoopVec<T, N>>
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
OptixCoopVecElemType inputInterpretation,
OptixCoopVecElemType matrixInterpretation,
OptixCoopVecMatrixLayout matrixLayout>
Comment thread
coderabbitai[bot] marked this conversation as resolved.
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
Expand All @@ -6553,26 +6496,22 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
inputInterpretation,
matrixLayout,
false,
Comment thread
jkwak-work marked this conversation as resolved.
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
Comment thread
jkwak-work marked this conversation as resolved.
}

// OptiX cooperative vector matrix multiplication wrapper (WITH bias - 6 runtime params)
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout,
unsigned biasInterpretation>
OptixCoopVecElemType inputInterpretation,
OptixCoopVecElemType matrixInterpretation,
OptixCoopVecMatrixLayout matrixLayout,
OptixCoopVecElemType biasInterpretation>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
Expand All @@ -6588,29 +6527,23 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
inputInterpretation,
matrixLayout,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value,
SlangToOptixComponentType<biasInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
bias,
biasOffset,
matrixStride);
matrixInterpretation,
biasInterpretation>(inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride);
}

// OptiX cooperative vector matrix multiplication wrapper (WITHOUT bias, 4 runtime params -
// StructuredBuffer variant)
template<
typename VecTOut,
typename VecTIn,
unsigned inputInterpretation,
unsigned matrixInterpretation,
unsigned matrixLayout>
OptixCoopVecElemType inputInterpretation,
OptixCoopVecElemType matrixInterpretation,
OptixCoopVecMatrixLayout matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
const VecTIn& inputVector,
CUdeviceptr matrix,
Expand All @@ -6624,16 +6557,12 @@ __forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
return optixCoopVecMatMul<
VecTOut,
VecTIn,
SlangToOptixComponentType<inputInterpretation>::value,
SlangToOptixMatrixLayout<matrixLayout>::value,
inputInterpretation,
matrixLayout,
false,
N,
K,
SlangToOptixComponentType<matrixInterpretation>::value>(
inputVector,
matrix,
matrixOffset,
matrixStride);
matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
}

#endif // (OPTIX_VERSION >= 90000)
Expand Down
7 changes: 6 additions & 1 deletion source/core/slang-type-text-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ namespace
x(Int64, int64_t) \
x(UInt64, uint64_t) \
x(Float32, float) \
x(Float64, double)
x(Float64, double) \
x(IntPtr, intptr_t) \
x(UIntPtr, uintptr_t) \
x(BFloat16, bfloat16) \
x(FloatE4M3, float_e4m3) \
x(FloatE5M2, float_e5m2)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// clang-format on

struct ScalarTypeInfo
Expand Down
7 changes: 6 additions & 1 deletion source/slang-wasm/slang-wasm-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ EMSCRIPTEN_BINDINGS(slang)
.value("Int8", slang::TypeReflection::ScalarType::Int8)
.value("UInt8", slang::TypeReflection::ScalarType::UInt8)
.value("Int16", slang::TypeReflection::ScalarType::Int16)
.value("UInt16", slang::TypeReflection::ScalarType::UInt16);
.value("UInt16", slang::TypeReflection::ScalarType::UInt16)
.value("IntPtr", slang::TypeReflection::ScalarType::IntPtr)
.value("UIntPtr", slang::TypeReflection::ScalarType::UIntPtr)
.value("BFloat16", slang::TypeReflection::ScalarType::BFloat16)
.value("FloatE4M3", slang::TypeReflection::ScalarType::FloatE4M3)
.value("FloatE5M2", slang::TypeReflection::ScalarType::FloatE5M2);

class_<slang::wgsl::TypeReflection>("TypeReflection")
.function("getScalarType", &slang::wgsl::TypeReflection::getScalarType)
Expand Down
Loading
Loading