Add API to list coopMat/coopVec types and combinations #10076
cmarcelo wants to merge 1 commit into shader-slang:master
📝 Walkthrough
This pull request adds comprehensive support for cooperative matrix and vector operations across the Slang compiler. It introduces new public API types and metadata interfaces for exposing cooperative type information, extends IR with five new cooperative instructions and validation logic, implements codegen for multiple targets (CUDA OptiX, HLSL, SPIR-V), and adds three new scalar types (BFloat16, FloatE4M3, FloatE5M2) plus pointer-sized integer types.
Pull request overview
This pull request adds a new API to expose cooperative matrix type metadata, allowing applications to query which cooperative matrix type combinations are used in compiled shaders. This addresses issue #10021, where applications need to verify driver support for specific cooperative matrix type combinations used in shaders.
Changes:
- Adds new public API types (`SlangScope`, `SlangCooperativeMatrixUse`, `SlangCooperativeComponentType`, `SlangCooperativeMatrixType`) and the `ICooperativeMatrixMetadata` interface to slang.h
- Implements metadata collection that recursively scans IR to identify all cooperative matrix types used in the compiled code
- Includes comprehensive unit tests for both the new cooperative matrix metadata API and the binarySearch changes from dependency PRs
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| include/slang.h | Adds new enums, struct, and interface for querying cooperative matrix types used in compiled shaders |
| source/compiler-core/slang-artifact-associated-impl.h | Extends ArtifactPostEmitMetadata to implement ICooperativeMatrixMetadata interface |
| source/compiler-core/slang-artifact-associated-impl.cpp | Implements the cooperative matrix metadata query methods with proper error handling |
| source/slang/slang-ir-metadata.cpp | Adds IR type conversion helpers and recursive metadata collection for cooperative matrix types |
| source/core/slang-list.h | Updates binarySearch to return bitwise negation of insertion index on miss (from PR #10048) |
| source/slang/slang-check-decl.cpp | Replaces binarySearch with indexOf for declaration ordering (from PR #10047) |
| source/slang/slang-language-server-auto-format.cpp | Updates exclusion range check to use >= 0 instead of != -1 (from PR #10048) |
| tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp | Comprehensive tests for the cooperative matrix metadata API |
| tools/slang-unit-test/unit-test-list.cpp | Unit tests for the updated binarySearch behavior |
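
The table mentions that binarySearch now returns the bitwise negation of the insertion index on a miss, and that a caller switched from `!= -1` to `>= 0`. A minimal sketch of that convention follows, using `std::vector<int>` and a hypothetical function name; it is an illustration of the described behavior, not the code in source/core/slang-list.h.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the behavior described in the table; this is
// not the actual binarySearch from source/core/slang-list.h.
// Returns the index of `key` if present, otherwise the bitwise negation (~)
// of the index at which `key` would have to be inserted to keep the order.
inline std::ptrdiff_t binarySearchSketch(const std::vector<int>& sorted, int key)
{
    std::ptrdiff_t lo = 0;
    std::ptrdiff_t hi = static_cast<std::ptrdiff_t>(sorted.size()) - 1;
    while (lo <= hi)
    {
        const std::ptrdiff_t mid = lo + (hi - lo) / 2;
        const int current = sorted[static_cast<std::size_t>(mid)];
        if (current == key)
            return mid;
        if (current < key)
            lo = mid + 1;
        else
            hi = mid - 1;
    }
    // `lo` is the insertion point. ~lo is always negative, so callers test
    // `result >= 0` for a hit instead of comparing against -1.
    return ~lo;
}
```

Under this convention only a miss at insertion index 0 yields -1; a miss at index 2 yields -3, which is why a `!= -1` check would no longer be a reliable hit test.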
Some context for this PR in #10021 (comment).
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@source/compiler-core/slang-artifact-associated-impl.cpp`:
- Around line 283-293: ArtifactPostEmitMetadata::getInterface is returning the
wrong interface pointer for the slang::IMetadata GUID (it returns
IArtifactPostEmitMetadata*), violating COM rules; update the GUID checks so that
when guid == slang::IMetadata::getTypeGuid() you return
static_cast<slang::IMetadata*>(this), when guid ==
IArtifactPostEmitMetadata::getTypeGuid() you return
static_cast<IArtifactPostEmitMetadata*>(this), and keep the ISlangUnknown,
ICastable and slang::ICooperativeMatrixMetadata cases returning their
corresponding static_casts so each GUID returns the correctly typed interface
pointer.
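
To illustrate why the finding above matters, here is a small sketch of an interface lookup on a class with multiple interface bases. All type and enum names are hypothetical stand-ins (the real code compares GUIDs and uses the Slang interfaces); the point is only that each identifier should map to a `static_cast` to its own interface type.

```cpp
#include <cstdint>

// Hypothetical stand-ins; the real code compares GUIDs, not an enum.
enum class InterfaceIdSketch : std::uint32_t { Metadata, PostEmitMetadata, Unknown };

struct IMetadataSketch
{
    virtual ~IMetadataSketch() = default;
    virtual int metadataValue() const = 0;
};

struct IPostEmitMetadataSketch
{
    virtual ~IPostEmitMetadataSketch() = default;
    virtual int postEmitValue() const = 0;
};

struct MetadataImplSketch : IPostEmitMetadataSketch, IMetadataSketch
{
    int metadataValue() const override { return 1; }
    int postEmitValue() const override { return 2; }

    // Each id is answered with a static_cast to that exact interface. The
    // base subobjects sit at different addresses, so returning the wrong
    // base as void* would make the caller dispatch through the wrong vtable.
    void* getInterface(InterfaceIdSketch id)
    {
        switch (id)
        {
        case InterfaceIdSketch::Metadata:
            return static_cast<IMetadataSketch*>(this);
        case InterfaceIdSketch::PostEmitMetadata:
            return static_cast<IPostEmitMetadataSketch*>(this);
        case InterfaceIdSketch::Unknown:
            break;
        }
        return nullptr;
    }
};
```

The caller casts the returned `void*` back to the interface it asked for, which is only valid when the pointer was produced by a cast to that same interface.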
In `@source/slang/slang-check-decl.cpp`:
- Around line 3581-3582: The two indexOf() results (from
ancestor->getMembers().indexOf(subAncestor) and indexOf(supAncestor)) must be
validated for -1 before being used in arithmetic; update the code to check for
-1 and either assert (e.g., SLANG_ASSERT) or apply the same fallback/ordering
logic as used in _compareDeclsInCommonParentByOrderOfDeclaration() so a missing
member does not produce incorrect ordering. Locate the variables subIndex and
supIndex in this block and add the check/early-return or normalized ordering
fallback when either index is -1, and log or assert with clear context
referencing subAncestor/supAncestor to aid debugging.
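
The validation requested above could look roughly like the following sketch. It uses `std::vector<int>` in place of the member list and hypothetical function names; how the real code should react to a missing member (assert, fallback ordering) is left to the caller here.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper mirroring an indexOf() that yields -1 on a miss.
inline std::ptrdiff_t indexOfSketch(const std::vector<int>& members, int value)
{
    const auto it = std::find(members.begin(), members.end(), value);
    return it == members.end() ? -1 : std::distance(members.begin(), it);
}

// Computes the declaration-order difference between `sub` and `sup`.
// Returns false, leaving `outDiff` untouched, when either one is not a
// member, so a -1 sentinel never takes part in the subtraction.
inline bool tryCompareDeclOrderSketch(
    const std::vector<int>& members,
    int sub,
    int sup,
    std::ptrdiff_t& outDiff)
{
    const std::ptrdiff_t subIndex = indexOfSketch(members, sub);
    const std::ptrdiff_t supIndex = indexOfSketch(members, sup);
    if (subIndex < 0 || supIndex < 0)
        return false;
    outDiff = subIndex - supIndex;
    return true;
}
```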
86eccf9 to 5878845
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp`:
- Around line 37-50: The test must guard and skip when SPIR-V or the
spvCooperativeMatrixKHR capability is unavailable; before proceeding with
targetDesc/compilation use the project's test-skip pattern (SLANG_IGNORE_TEST)
to check that globalSession->findProfile("spirv_1_6") is non-null and that
int32_t(globalSession->findCapability("spvCooperativeMatrixKHR")) is non-zero
(or otherwise indicates support), and if either check fails call
SLANG_IGNORE_TEST with an explanatory message instead of continuing; update the
logic around globalSession, targetDesc.format (SLANG_SPIRV), and
capabilityOption so the test is skipped early when those features are absent.
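
The early-skip guard described above can be sketched without any Slang dependency. In this hypothetical version the two ids are plain integers; in the real test they would come from `globalSession->findProfile("spirv_1_6")` and `globalSession->findCapability("spvCooperativeMatrixKHR")`, and a non-null reason would be handed to the project's skip mechanism (the review names SLANG_IGNORE_TEST). Treating an id of 0 as "not found" is an assumption taken from the review's wording.

```cpp
// Hypothetical, dependency-free sketch of the guard. Assumption: an id of 0
// means the profile or capability was not found.
// Returns a human-readable skip reason, or nullptr when the test should run.
inline const char* skipReasonSketch(int profileId, int capabilityId)
{
    if (profileId == 0)
        return "spirv_1_6 profile is unavailable";
    if (capabilityId == 0)
        return "spvCooperativeMatrixKHR capability is unavailable";
    return nullptr; // Both features are present: continue with compilation.
}
```

Running this check before `targetDesc` is filled in keeps the test from failing on builds where SPIR-V or the capability is absent.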
5878845 to 187b4bb
jkwak-work left a comment
This PR wouldn't provide the information needed to query the combination of types that the graphics driver supports.
include/slang.h (Outdated)
    Query this interface from an `IMetadata*`.
    */
    struct ICooperativeMatrixMetadata : public ISlangUnknown
I am not sure we want a separate struct for this purpose.
Wouldn't it be much simpler to move the functions into IMetadata, as initially proposed in the issue comment?
We can just use the existing IArtifactPostEmitMetadata interface.
I might be wrong, but that interface doesn't look public. Should I make it public?
> Wouldn't be much simpler to move the functions into IMetadata as initially proposed in #10021 (comment)?

Wouldn't that break the ABI of IMetadata?
@csyonghe my understanding is that adding new virtual methods to IMetadata would break ABI compatibility, so in this iteration I've kept it as a separate interface. Let me know if we still want to just add the functions to IMetadata.
187b4bb to 9f61c6b
9f61c6b to 47bed14
@coderabbitai summary
✅ Actions performed: Summary regeneration triggered.
47bed14 to 12fda08
csyonghe left a comment
In general, I think we first need to refactor the way coopMatMulAdd and other coop-vec intrinsics are implemented in the compiler. We need to move the lowering logic from hlsl.meta.slang (library code) into each backend so that we can map each intrinsic to its dedicated Slang IR opcode, which simplifies the metadata collection logic.
@cmarcelo, were you able to make some progress on this PR?
Working on the suggestions. Will update PR here in the next few days.
01070ec to a00945c
8d6a9d6 to ae3fd3c
40f4137 to 652560c
Expose cooperative matrix and cooperative vector metadata through a new `ICooperativeTypesMetadata` interface queried from `IMetadata`. This information can be used to query the combinations of types that the graphics driver supports. Closes shader-slang#10021.
Verdict: ✅ Clean — no significant issues found
This PR adds a new ICooperativeTypesMetadata COM interface (queryable via castAs from IMetadata) that exposes cooperative matrix and cooperative vector type/combination metadata collected after target-specific lowering. The implementation is correct: metadata collection is positioned after validateCooperativeOperations (ensuring safe cast<> usage) and after lowerCooperativeVectors (ensuring only native constructs are reported). Tests cover 7 scenarios across SPIRV, HLSL, CUDA, and CUDA+optix targets, including edge cases (empty shaders, lowered targets, training ops, mixed usage, error arguments).
Changes Overview
Public API (include/slang.h)
- What changed: Added SlangScope enum, four POD structs (CooperativeMatrixType, CooperativeMatrixCombination, CooperativeVectorType, CooperativeVectorCombination), and the ICooperativeTypesMetadata interface with 8 methods (count + getByIndex for each of the four collections). Structs define operator== where needed for deduplication. CooperativeVectorType intentionally omits operator== because entries are aggregated by componentType (merging maxSize and usedForTrainingOp).
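
The aggregation by componentType mentioned above could work along these lines. The field names come from the review text; their types, the function name, and the exact merge rule (maximum of the sizes, logical OR of the training flag) are assumptions made for this sketch.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Field names follow the review text; the types and the merge rule are
// assumptions for illustration, not the declarations in include/slang.h.
struct CoopVecTypeSketch
{
    int componentType;
    std::uint32_t maxSize;
    bool usedForTrainingOp;
};

// Keeps one entry per componentType. A repeated componentType updates the
// existing entry instead of adding a duplicate, which is why plain equality
// on the whole struct would not be the right deduplication key.
inline void insertOrUpdateSketch(
    std::vector<CoopVecTypeSketch>& types,
    const CoopVecTypeSketch& incoming)
{
    for (auto& existing : types)
    {
        if (existing.componentType == incoming.componentType)
        {
            existing.maxSize = std::max(existing.maxSize, incoming.maxSize);
            existing.usedForTrainingOp =
                existing.usedForTrainingOp || incoming.usedForTrainingOp;
            return;
        }
    }
    types.push_back(incoming);
}
```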
COM implementation (source/compiler-core/slang-artifact-associated-impl.cpp, source/compiler-core/slang-artifact-associated-impl.h)
- What changed: ArtifactPostEmitMetadata now additionally inherits ICooperativeTypesMetadata. The getInterface method was extended to handle both IMetadata and ICooperativeTypesMetadata GUIDs (the IMetadata case was previously missing — this PR fixes that gap). All 8 interface methods perform null-pointer and bounds validation returning SLANG_E_INVALID_ARG. Four List<> members store the collected metadata.
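
A count plus get-by-index accessor pair with the null-pointer and bounds validation described above might be shaped like this. Class, method, and field names are hypothetical, and the result codes are placeholders rather than the real SLANG_OK and SLANG_E_INVALID_ARG values.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Placeholder result codes; the real values come from slang.h and are not
// reproduced here.
using ResultSketch = std::int32_t;
constexpr ResultSketch kOkSketch = 0;
constexpr ResultSketch kInvalidArgSketch = -1;

// Hypothetical POD entry; field names are illustrative only.
struct MatrixTypeSketch
{
    int componentType;
    std::uint32_t rowCount;
    std::uint32_t columnCount;
};

class CooperativeMetadataSketch
{
public:
    explicit CooperativeMetadataSketch(std::vector<MatrixTypeSketch> types)
        : m_types(std::move(types))
    {
    }

    std::size_t getMatrixTypeCount() const { return m_types.size(); }

    // Rejects a null output pointer and an out-of-range index before
    // touching the list, so a bad call cannot read or write out of bounds.
    ResultSketch getMatrixTypeByIndex(std::size_t index, MatrixTypeSketch* outType) const
    {
        if (!outType)
            return kInvalidArgSketch;
        if (index >= m_types.size())
            return kInvalidArgSketch;
        *outType = m_types[index];
        return kOkSketch;
    }

private:
    std::vector<MatrixTypeSketch> m_types;
};
```

A caller would typically read the count first and then loop over the indices, checking the result of each call.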
IR metadata collection pass (source/slang/slang-ir-metadata.cpp, source/slang/slang-ir-metadata.h)
- What changed: New collectCooperativeMetadata function performs a breadth-first traversal of the IR module, collecting cooperative types from IRCoopMatrixType/IRCoopVectorType instructions (via getResolvedInstForDecorations) and combinations from IRCoopMatMulAdd/IRCoopVecMatMulAdd/training ops. Types are stored in sorted order using custom operator< (defined in the slang namespace within the .cpp) with _insertSortedUnique for deduplication. Cooperative vector types use _insertOrUpdateCooperativeVectorType to merge entries by componentType. Helper functions validate IR values using std::in_range (first usage of this C++20 utility in the codebase — project targets C++20) and explicit enum case enumeration without default to get compiler warnings on enum growth.
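
Sorted storage with deduplication through `operator<`, as described above, can be sketched with `std::lower_bound`. The function below is a hypothetical analogue of `_insertSortedUnique` written against `std::vector`; it is not the PR's implementation.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical analogue of the sorted, deduplicating insert described
// above. Only operator< is required: two values count as duplicates when
// neither is less than the other.
template<typename T>
bool insertSortedUniqueSketch(std::vector<T>& sorted, const T& value)
{
    // First position whose element is not less than `value`.
    const auto it = std::lower_bound(sorted.begin(), sorted.end(), value);
    if (it != sorted.end() && !(value < *it))
        return false; // An equivalent entry already exists.
    sorted.insert(it, value);
    return true;
}
```

Keeping the list sorted makes the reported order independent of the order in which the traversal happens to encounter the types.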
Pipeline integration (source/slang/slang-emit.cpp)
- What changed: collectCooperativeMetadata is called immediately after validateCooperativeOperations and after all target-specific lowering passes. This ensures cast<> usage in the metadata collector is safe (validation guarantees operand types) and that targets which lower cooperative types to arrays (plain CUDA, Metal, WGSL, GLSL) correctly report empty metadata lists.
Unit tests (tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp)
- What changed: 859-line test file with 7 test cases covering: subgroup and workgroup matrix types/combinations (SPIRV + CUDA), cooperative vector types/combinations with packed inputs (SPIRV + HLSL + CUDA+optix), training operations (outer product, reduce-sum), mixed training/non-training, lowered vector targets (plain CUDA produces empty), and empty shader baseline. Tests validate null-pointer and out-of-bounds error returns, exact count matching, and per-entry field correctness.