Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions include/slang.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,15 @@ typedef uint32_t SlangSizeT;
SLANG_STAGE_PIXEL = SLANG_STAGE_FRAGMENT,
};

typedef SlangUInt32 SlangScopeIntegral;
enum SlangScope : SlangScopeIntegral
Comment thread
jkwak-work marked this conversation as resolved.
Comment thread
jkwak-work marked this conversation as resolved.
{
SLANG_SCOPE_NONE,
SLANG_SCOPE_THREAD,
SLANG_SCOPE_WAVE,
Comment thread
jkwak-work marked this conversation as resolved.
Comment thread
jkwak-work marked this conversation as resolved.
Comment thread
jkwak-work marked this conversation as resolved.
Comment on lines +860 to +865
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Gap: New public enum SlangScope lacks documentation

This is a new public API enum that maps GPU execution scopes for cooperative operations. Unlike the struct fields and the ICooperativeTypesMetadata interface (which have doc comments), the enum values have no documentation. Users need to understand the mapping to GPU concepts (thread = invocation, wave = subgroup, thread group = workgroup) and how these correspond to Vulkan/HLSL scopes.

Suggestion: Add a brief doc comment and per-value annotations:

/** Cooperative type execution scope.
 * Specifies the scope of threads that cooperate in matrix/vector operations.
 */
typedef SlangUInt32 SlangScopeIntegral;
enum SlangScope : SlangScopeIntegral
{
    SLANG_SCOPE_NONE,         ///< Invalid/unspecified scope
    SLANG_SCOPE_THREAD,       ///< Single invocation
    SLANG_SCOPE_WAVE,         ///< Subgroup/wave scope
    SLANG_SCOPE_THREAD_GROUP, ///< Workgroup/thread group scope
};

SLANG_SCOPE_THREAD_GROUP,
Comment thread
jkwak-work marked this conversation as resolved.
};
Comment thread
jkwak-work marked this conversation as resolved.
Comment thread
jkwak-work marked this conversation as resolved.
Comment on lines +861 to +867
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this same to the one we already have?

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
// must be 32 bit to match SPIR-V
enum class MemoryScope : int32_t
{
CrossDevice = 0,
Device = 1,
Workgroup = 2,
Subgroup = 3,
Invocation = 4,
QueueFamily = 5,
ShaderCall = 6,
};

LLMs are also point problems on this enum.
I think the enum name should be more like MemoryScope rather than just Scope.
And the values should match to the internal and SPIRV values.


Comment thread
jkwak-work marked this conversation as resolved.
typedef SlangUInt32 SlangCooperativeMatrixUseIntegral;
enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral
{
Expand Down Expand Up @@ -4488,6 +4497,136 @@ struct IMetadata : public ISlangCastable
};
#define SLANG_UUID_IMetadata IMetadata::getTypeGuid()

struct CooperativeMatrixType
Comment thread
jkwak-work marked this conversation as resolved.
{
// Component type `NONE` means this type is not valid.
SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE;
SlangScope scope = SLANG_SCOPE_NONE;

uint32_t rowCount = 0;
uint32_t columnCount = 0;
Comment thread
cmarcelo marked this conversation as resolved.

Comment thread
jkwak-work marked this conversation as resolved.
SlangCooperativeMatrixUse use = SLANG_COOPERATIVE_MATRIX_USE_A;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 Question: Default .use = SLANG_COOPERATIVE_MATRIX_USE_A for invalid types

All other fields in CooperativeMatrixType default to zero/NONE to indicate "not valid," but .use defaults to SLANG_COOPERATIVE_MATRIX_USE_A (value 0) because the SlangCooperativeMatrixUse enum has no NONE sentinel. This means a default-constructed or partially-initialized struct has a field that looks valid (USE_A) even when the struct is logically invalid (componentType == NONE).

The doc comment says componentType == NONE is the validity check, so this is consistent — but it could confuse callers who inspect .use without first checking componentType. Is this intentional, or would it be worth adding a SLANG_COOPERATIVE_MATRIX_USE_NONE sentinel to the existing enum (recognizing that changes the enum layout)?


bool operator==(const CooperativeMatrixType& other) const
{
return componentType == other.componentType && scope == other.scope &&
rowCount == other.rowCount && columnCount == other.columnCount && use == other.use;
}
};

struct CooperativeMatrixCombination
{
// Number of rows of matrix A and the result.
uint32_t m = 0;
Comment thread
jkwak-work marked this conversation as resolved.
// Number of columns of matrix B and the result.
uint32_t n = 0;
// Shared inner dimension: columns of A and rows of B.
uint32_t k = 0;
Comment thread
jkwak-work marked this conversation as resolved.

SlangScalarType componentTypeA = SLANG_SCALAR_TYPE_NONE;
SlangScalarType componentTypeB = SLANG_SCALAR_TYPE_NONE;
SlangScalarType componentTypeC = SLANG_SCALAR_TYPE_NONE;
SlangScalarType componentTypeResult = SLANG_SCALAR_TYPE_NONE;

SlangBool saturate = false;
SlangScope scope = SLANG_SCOPE_NONE;

bool operator==(const CooperativeMatrixCombination& other) const
Comment thread
jkwak-work marked this conversation as resolved.
{
return m == other.m && n == other.n && k == other.k &&
componentTypeA == other.componentTypeA && componentTypeB == other.componentTypeB &&
componentTypeC == other.componentTypeC &&
componentTypeResult == other.componentTypeResult && saturate == other.saturate &&
scope == other.scope;
}
};

struct CooperativeVectorType
{
Comment thread
jkwak-work marked this conversation as resolved.
SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE;

// Maximum element count for this component type across all cooperative
// vectors in the shader. Zero means the type appears only in training
// operations; non-zero does not preclude `usedForTrainingOp` being true.
uint32_t maxSize = 0;
Comment thread
jkwak-work marked this conversation as resolved.

// Indicates this component type is used as the accumulation/storage type for cooperative
// vector training operations such as outer-product accumulation and reduce-sum
// accumulation.
SlangBool usedForTrainingOp = false;
};
Comment thread
jkwak-work marked this conversation as resolved.

struct CooperativeVectorCombination
{
Comment thread
jkwak-work marked this conversation as resolved.
SlangScalarType inputType = SLANG_SCALAR_TYPE_NONE;
Comment thread
jkwak-work marked this conversation as resolved.
SlangScalarType inputInterpretation = SLANG_SCALAR_TYPE_NONE;
// Number of logical elements packed into each physical input element.
// For example, this is 4 when four int8 values are packed into one uint32 input element.
uint32_t inputPackingFactor = 1;
Comment thread
jkwak-work marked this conversation as resolved.
SlangScalarType matrixInterpretation = SLANG_SCALAR_TYPE_NONE;
// `NONE` means the operation has no bias operand/matrix.
SlangScalarType biasInterpretation = SLANG_SCALAR_TYPE_NONE;
SlangScalarType resultType = SLANG_SCALAR_TYPE_NONE;
SlangBool transpose = false;

bool operator==(const CooperativeVectorCombination& other) const
{
return inputType == other.inputType && inputInterpretation == other.inputInterpretation &&
inputPackingFactor == other.inputPackingFactor &&
matrixInterpretation == other.matrixInterpretation &&
biasInterpretation == other.biasInterpretation && resultType == other.resultType &&
transpose == other.transpose;
}
};

/** Cooperative matrix and vector metadata.

This interface reports the cooperative matrix/vector type information that a compiled target uses,
including both types and certain type combinations required to execute some operations (like matrix
multiplication).

Applications can use this metadata to compare shader requirements against the capabilities exposed
by the target API/driver (for example Vulkan cooperative matrix/vector property queries, or
analogous APIs on other backends).

Metadata is collected from the IR after target-specific lowering, so it only reflects cooperative
types that survive as native constructs in the final output. Targets that lower cooperative types
into ordinary arrays will report empty lists.

Lists are exposed using `get*Count()` plus `get*ByIndex()` methods, where the count returns the
number of elements currently available and valid indices are in the range `[0, count)`.

Cast from an `IMetadata*` using `castAs()`.
*/
Comment thread
coderabbitai[bot] marked this conversation as resolved.
struct ICooperativeTypesMetadata : public ISlangCastable
{
SLANG_COM_INTERFACE(
0x64c4d536,
0xd949,
0x49c3,
{0x9f, 0xde, 0x3f, 0x0f, 0x9c, 0x6f, 0x01, 0x31})

virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() = 0;
Comment thread
jkwak-work marked this conversation as resolved.
virtual SlangResult SLANG_MCALL
getCooperativeMatrixTypeByIndex(SlangUInt index, CooperativeMatrixType* outType) = 0;
Comment thread
jkwak-work marked this conversation as resolved.

virtual SlangUInt SLANG_MCALL getCooperativeMatrixCombinationCount() = 0;
virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex(
SlangUInt index,
CooperativeMatrixCombination* outCombination) = 0;

virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() = 0;
virtual SlangResult SLANG_MCALL
getCooperativeVectorTypeByIndex(SlangUInt index, CooperativeVectorType* outType) = 0;

virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount() = 0;
virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex(
SlangUInt index,
CooperativeVectorCombination* outCombination) = 0;
};
#define SLANG_UUID_ICooperativeTypesMetadata ICooperativeTypesMetadata::getTypeGuid()

/** Compile result for storing and retrieving multiple output blobs.
This is needed for features such as separate debug compilation which
output both base and debug spirv.
Expand Down
80 changes: 80 additions & 0 deletions source/compiler-core/slang-artifact-associated-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,10 @@ void* ArtifactPostEmitMetadata::getInterface(const Guid& guid)
{
return static_cast<IArtifactPostEmitMetadata*>(this);
}
if (guid == slang::IMetadata::getTypeGuid())
return static_cast<slang::IMetadata*>(this);
if (guid == slang::ICooperativeTypesMetadata::getTypeGuid())
Comment thread
jkwak-work marked this conversation as resolved.
return static_cast<slang::ICooperativeTypesMetadata*>(this);
return nullptr;
}

Expand Down Expand Up @@ -344,5 +348,81 @@ const char* ArtifactPostEmitMetadata::getDebugBuildIdentifier()
return m_debugBuildIdentifier.getBuffer();
}

SlangUInt ArtifactPostEmitMetadata::getCooperativeMatrixTypeCount()
{
return SlangUInt(m_cooperativeMatrixTypes.getCount());
}

SlangResult ArtifactPostEmitMetadata::getCooperativeMatrixTypeByIndex(
SlangUInt index,
slang::CooperativeMatrixType* outType)
{
if (!outType)
return SLANG_E_INVALID_ARG;

if (index >= SlangUInt(m_cooperativeMatrixTypes.getCount()))
return SLANG_E_INVALID_ARG;

*outType = m_cooperativeMatrixTypes[Index(index)];
return SLANG_OK;
}

SlangUInt ArtifactPostEmitMetadata::getCooperativeMatrixCombinationCount()
{
return SlangUInt(m_cooperativeMatrixCombinations.getCount());
}

SlangResult ArtifactPostEmitMetadata::getCooperativeMatrixCombinationByIndex(
SlangUInt index,
slang::CooperativeMatrixCombination* outCombination)
{
if (!outCombination)
return SLANG_E_INVALID_ARG;

if (index >= SlangUInt(m_cooperativeMatrixCombinations.getCount()))
return SLANG_E_INVALID_ARG;

*outCombination = m_cooperativeMatrixCombinations[Index(index)];
return SLANG_OK;
}

SlangUInt ArtifactPostEmitMetadata::getCooperativeVectorTypeCount()
{
return SlangUInt(m_cooperativeVectorTypes.getCount());
}

SlangResult ArtifactPostEmitMetadata::getCooperativeVectorTypeByIndex(
SlangUInt index,
slang::CooperativeVectorType* outType)
{
if (!outType)
return SLANG_E_INVALID_ARG;

if (index >= SlangUInt(m_cooperativeVectorTypes.getCount()))
return SLANG_E_INVALID_ARG;

*outType = m_cooperativeVectorTypes[Index(index)];
return SLANG_OK;
}

SlangUInt ArtifactPostEmitMetadata::getCooperativeVectorCombinationCount()
{
return SlangUInt(m_cooperativeVectorCombinations.getCount());
}

SlangResult ArtifactPostEmitMetadata::getCooperativeVectorCombinationByIndex(
SlangUInt index,
slang::CooperativeVectorCombination* outCombination)
{
if (!outCombination)
return SLANG_E_INVALID_ARG;

if (index >= SlangUInt(m_cooperativeVectorCombinations.getCount()))
return SLANG_E_INVALID_ARG;

*outCombination = m_cooperativeVectorCombinations[Index(index)];
return SLANG_OK;
}


} // namespace Slang
28 changes: 27 additions & 1 deletion source/compiler-core/slang-artifact-associated-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ struct ShaderBindingRange
}
};

class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitMetadata
class ArtifactPostEmitMetadata : public ComBaseObject,
public IArtifactPostEmitMetadata,
public slang::ICooperativeTypesMetadata
{
public:
typedef ArtifactPostEmitMetadata ThisType;
Expand All @@ -201,6 +203,26 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM

SLANG_NO_THROW virtual const char* SLANG_MCALL getDebugBuildIdentifier() SLANG_OVERRIDE;

// ICooperativeTypesMetadata
SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixTypeByIndex(
SlangUInt index,
slang::CooperativeMatrixType* outType) SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeMatrixCombinationCount()
SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex(
SlangUInt index,
slang::CooperativeMatrixCombination* outCombination) SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorTypeByIndex(
SlangUInt index,
slang::CooperativeVectorType* outType) SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount()
SLANG_OVERRIDE;
SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex(
SlangUInt index,
slang::CooperativeVectorCombination* outCombination) SLANG_OVERRIDE;

void* getInterface(const Guid& uuid);
void* getObject(const Guid& uuid);

Expand All @@ -211,6 +233,10 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM

List<ShaderBindingRange> m_usedBindings;
List<String> m_exportedFunctionMangledNames;
List<slang::CooperativeMatrixType> m_cooperativeMatrixTypes;
List<slang::CooperativeMatrixCombination> m_cooperativeMatrixCombinations;
List<slang::CooperativeVectorType> m_cooperativeVectorTypes;
List<slang::CooperativeVectorCombination> m_cooperativeVectorCombinations;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if List is a right container for this when you want to maintain the sorted order.
Although the name List sounds like a linked-list but it is more like std::vector.
The cost of inserting in the middle will be higher than you may expected.

String m_debugBuildIdentifier;
};

Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,10 @@ Result linkAndOptimizeIR(
auto metadata = new ArtifactPostEmitMetadata;
outLinkedIR.metadata = metadata;

Comment thread
jkwak-work marked this conversation as resolved.
// Runs after target-specific lowering so it only captures cooperative types that remain
// as native constructs visible to the driver (see ICooperativeTypesMetadata docs).
SLANG_PASS(collectCooperativeMetadata, *metadata);
Comment on lines +2265 to +2267
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can skip this pass when the cooperative capability is not used or cooperative-vector or cooperative-matrix types are not found earlier.


if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR))
{
SLANG_PASS(unexportNonEmbeddableIR, target);
Expand Down
Loading
Loading