-
Notifications
You must be signed in to change notification settings - Fork 437
Add API to list coopMat/coopVec types and combinations #10076
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2b426b1
2165e62
8968df2
b7738c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -857,6 +857,15 @@ typedef uint32_t SlangSizeT; | |||||||||||||||||||||||||
| SLANG_STAGE_PIXEL = SLANG_STAGE_FRAGMENT, | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| typedef SlangUInt32 SlangScopeIntegral; | ||||||||||||||||||||||||||
| enum SlangScope : SlangScopeIntegral | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| SLANG_SCOPE_NONE, | ||||||||||||||||||||||||||
| SLANG_SCOPE_THREAD, | ||||||||||||||||||||||||||
| SLANG_SCOPE_WAVE, | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
jkwak-work marked this conversation as resolved.
jkwak-work marked this conversation as resolved.
Comment on lines
+860
to
+865
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Gap: New public enum This is a new public API enum that maps GPU execution scopes for cooperative operations. Unlike the struct fields and the Suggestion: Add a brief doc comment and per-value annotations: /** Cooperative type execution scope.
* Specifies the scope of threads that cooperate in matrix/vector operations.
*/
typedef SlangUInt32 SlangScopeIntegral;
enum SlangScope : SlangScopeIntegral
{
SLANG_SCOPE_NONE, ///< Invalid/unspecified scope
SLANG_SCOPE_THREAD, ///< Single invocation
SLANG_SCOPE_WAVE, ///< Subgroup/wave scope
SLANG_SCOPE_THREAD_GROUP, ///< Workgroup/thread group scope
}; |
||||||||||||||||||||||||||
| SLANG_SCOPE_THREAD_GROUP, | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
jkwak-work marked this conversation as resolved.
Comment on lines
+861
to
+867
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this the same as the one we already have? slang/source/slang/slang-type-system-shared.h Lines 119 to 130 in e4d533d
LLMs are also pointing out problems with this enum. |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| typedef SlangUInt32 SlangCooperativeMatrixUseIntegral; | ||||||||||||||||||||||||||
| enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
|
|
@@ -4488,6 +4497,136 @@ struct IMetadata : public ISlangCastable | |||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
| #define SLANG_UUID_IMetadata IMetadata::getTypeGuid() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| struct CooperativeMatrixType | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| // Component type `NONE` means this type is not valid. | ||||||||||||||||||||||||||
| SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangScope scope = SLANG_SCOPE_NONE; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| uint32_t rowCount = 0; | ||||||||||||||||||||||||||
| uint32_t columnCount = 0; | ||||||||||||||||||||||||||
|
cmarcelo marked this conversation as resolved.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| SlangCooperativeMatrixUse use = SLANG_COOPERATIVE_MATRIX_USE_A; | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔵 Question: Default All other fields in The doc comment says |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| bool operator==(const CooperativeMatrixType& other) const | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| return componentType == other.componentType && scope == other.scope && | ||||||||||||||||||||||||||
| rowCount == other.rowCount && columnCount == other.columnCount && use == other.use; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| struct CooperativeMatrixCombination | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| // Number of rows of matrix A and the result. | ||||||||||||||||||||||||||
| uint32_t m = 0; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| // Number of columns of matrix B and the result. | ||||||||||||||||||||||||||
| uint32_t n = 0; | ||||||||||||||||||||||||||
| // Shared inner dimension: columns of A and rows of B. | ||||||||||||||||||||||||||
| uint32_t k = 0; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| SlangScalarType componentTypeA = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangScalarType componentTypeB = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangScalarType componentTypeC = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangScalarType componentTypeResult = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| SlangBool saturate = false; | ||||||||||||||||||||||||||
| SlangScope scope = SLANG_SCOPE_NONE; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| bool operator==(const CooperativeMatrixCombination& other) const | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| return m == other.m && n == other.n && k == other.k && | ||||||||||||||||||||||||||
| componentTypeA == other.componentTypeA && componentTypeB == other.componentTypeB && | ||||||||||||||||||||||||||
| componentTypeC == other.componentTypeC && | ||||||||||||||||||||||||||
| componentTypeResult == other.componentTypeResult && saturate == other.saturate && | ||||||||||||||||||||||||||
| scope == other.scope; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| struct CooperativeVectorType | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Maximum element count for this component type across all cooperative | ||||||||||||||||||||||||||
| // vectors in the shader. Zero means the type appears only in training | ||||||||||||||||||||||||||
| // operations; non-zero does not preclude `usedForTrainingOp` being true. | ||||||||||||||||||||||||||
| uint32_t maxSize = 0; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Indicates this component type is used as the accumulation/storage type for cooperative | ||||||||||||||||||||||||||
| // vector training operations such as outer-product accumulation and reduce-sum | ||||||||||||||||||||||||||
| // accumulation. | ||||||||||||||||||||||||||
| SlangBool usedForTrainingOp = false; | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| struct CooperativeVectorCombination | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| SlangScalarType inputType = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| SlangScalarType inputInterpretation = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| // Number of logical elements packed into each physical input element. | ||||||||||||||||||||||||||
| // For example, this is 4 when four int8 values are packed into one uint32 input element. | ||||||||||||||||||||||||||
| uint32_t inputPackingFactor = 1; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| SlangScalarType matrixInterpretation = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| // `NONE` means the operation has no bias operand/matrix. | ||||||||||||||||||||||||||
| SlangScalarType biasInterpretation = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangScalarType resultType = SLANG_SCALAR_TYPE_NONE; | ||||||||||||||||||||||||||
| SlangBool transpose = false; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| bool operator==(const CooperativeVectorCombination& other) const | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| return inputType == other.inputType && inputInterpretation == other.inputInterpretation && | ||||||||||||||||||||||||||
| inputPackingFactor == other.inputPackingFactor && | ||||||||||||||||||||||||||
| matrixInterpretation == other.matrixInterpretation && | ||||||||||||||||||||||||||
| biasInterpretation == other.biasInterpretation && resultType == other.resultType && | ||||||||||||||||||||||||||
| transpose == other.transpose; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| /** Cooperative matrix and vector metadata. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| This interface reports the cooperative matrix/vector type information that a compiled target uses, | ||||||||||||||||||||||||||
| including both types and certain type combinations required to execute some operations (like matrix | ||||||||||||||||||||||||||
| multiplication). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Applications can use this metadata to compare shader requirements against the capabilities exposed | ||||||||||||||||||||||||||
| by the target API/driver (for example Vulkan cooperative matrix/vector property queries, or | ||||||||||||||||||||||||||
| analogous APIs on other backends). | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Metadata is collected from the IR after target-specific lowering, so it only reflects cooperative | ||||||||||||||||||||||||||
| types that survive as native constructs in the final output. Targets that lower cooperative types | ||||||||||||||||||||||||||
| into ordinary arrays will report empty lists. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Lists are exposed using `get*Count()` plus `get*ByIndex()` methods, where the count returns the | ||||||||||||||||||||||||||
| number of elements currently available and valid indices are in the range `[0, count)`. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Cast from an `IMetadata*` using `castAs()`. | ||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||
| struct ICooperativeTypesMetadata : public ISlangCastable | ||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||
| SLANG_COM_INTERFACE( | ||||||||||||||||||||||||||
| 0x64c4d536, | ||||||||||||||||||||||||||
| 0xd949, | ||||||||||||||||||||||||||
| 0x49c3, | ||||||||||||||||||||||||||
| {0x9f, 0xde, 0x3f, 0x0f, 0x9c, 0x6f, 0x01, 0x31}) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() = 0; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
| virtual SlangResult SLANG_MCALL | ||||||||||||||||||||||||||
| getCooperativeMatrixTypeByIndex(SlangUInt index, CooperativeMatrixType* outType) = 0; | ||||||||||||||||||||||||||
|
jkwak-work marked this conversation as resolved.
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| virtual SlangUInt SLANG_MCALL getCooperativeMatrixCombinationCount() = 0; | ||||||||||||||||||||||||||
| virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex( | ||||||||||||||||||||||||||
| SlangUInt index, | ||||||||||||||||||||||||||
| CooperativeMatrixCombination* outCombination) = 0; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() = 0; | ||||||||||||||||||||||||||
| virtual SlangResult SLANG_MCALL | ||||||||||||||||||||||||||
| getCooperativeVectorTypeByIndex(SlangUInt index, CooperativeVectorType* outType) = 0; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount() = 0; | ||||||||||||||||||||||||||
| virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex( | ||||||||||||||||||||||||||
| SlangUInt index, | ||||||||||||||||||||||||||
| CooperativeVectorCombination* outCombination) = 0; | ||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||
| #define SLANG_UUID_ICooperativeTypesMetadata ICooperativeTypesMetadata::getTypeGuid() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| /** Compile result for storing and retrieving multiple output blobs. | ||||||||||||||||||||||||||
| This is needed for features such as separate debug compilation which | ||||||||||||||||||||||||||
| output both base and debug spirv. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -174,7 +174,9 @@ struct ShaderBindingRange | |
| } | ||
| }; | ||
|
|
||
| class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitMetadata | ||
| class ArtifactPostEmitMetadata : public ComBaseObject, | ||
| public IArtifactPostEmitMetadata, | ||
| public slang::ICooperativeTypesMetadata | ||
| { | ||
| public: | ||
| typedef ArtifactPostEmitMetadata ThisType; | ||
|
|
@@ -201,6 +203,26 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM | |
|
|
||
| SLANG_NO_THROW virtual const char* SLANG_MCALL getDebugBuildIdentifier() SLANG_OVERRIDE; | ||
|
|
||
| // ICooperativeTypesMetadata | ||
| SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixTypeByIndex( | ||
| SlangUInt index, | ||
| slang::CooperativeMatrixType* outType) SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeMatrixCombinationCount() | ||
| SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex( | ||
| SlangUInt index, | ||
| slang::CooperativeMatrixCombination* outCombination) SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorTypeByIndex( | ||
| SlangUInt index, | ||
| slang::CooperativeVectorType* outType) SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount() | ||
| SLANG_OVERRIDE; | ||
| SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex( | ||
| SlangUInt index, | ||
| slang::CooperativeVectorCombination* outCombination) SLANG_OVERRIDE; | ||
|
|
||
| void* getInterface(const Guid& uuid); | ||
| void* getObject(const Guid& uuid); | ||
|
|
||
|
|
@@ -211,6 +233,10 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM | |
|
|
||
| List<ShaderBindingRange> m_usedBindings; | ||
| List<String> m_exportedFunctionMangledNames; | ||
| List<slang::CooperativeMatrixType> m_cooperativeMatrixTypes; | ||
| List<slang::CooperativeMatrixCombination> m_cooperativeMatrixCombinations; | ||
| List<slang::CooperativeVectorType> m_cooperativeVectorTypes; | ||
| List<slang::CooperativeVectorCombination> m_cooperativeVectorCombinations; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if |
||
| String m_debugBuildIdentifier; | ||
| }; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2262,6 +2262,10 @@ Result linkAndOptimizeIR( | |
| auto metadata = new ArtifactPostEmitMetadata; | ||
| outLinkedIR.metadata = metadata; | ||
|
|
||
|
jkwak-work marked this conversation as resolved.
|
||
| // Runs after target-specific lowering so it only captures cooperative types that remain | ||
| // as native constructs visible to the driver (see ICooperativeTypesMetadata docs). | ||
| SLANG_PASS(collectCooperativeMetadata, *metadata); | ||
|
Comment on lines
+2265
to
+2267
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can skip this pass when the cooperative capability is not used, or when no cooperative-vector or cooperative-matrix types were found earlier. |
||
|
|
||
| if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) | ||
| { | ||
| SLANG_PASS(unexportNonEmbeddableIR, target); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.