diff --git a/include/slang.h b/include/slang.h index b80462210f0..bb7582207be 100644 --- a/include/slang.h +++ b/include/slang.h @@ -857,6 +857,15 @@ typedef uint32_t SlangSizeT; SLANG_STAGE_PIXEL = SLANG_STAGE_FRAGMENT, }; + typedef SlangUInt32 SlangScopeIntegral; + enum SlangScope : SlangScopeIntegral + { + SLANG_SCOPE_NONE, + SLANG_SCOPE_THREAD, + SLANG_SCOPE_WAVE, + SLANG_SCOPE_THREAD_GROUP, + }; + typedef SlangUInt32 SlangCooperativeMatrixUseIntegral; enum SlangCooperativeMatrixUse : SlangCooperativeMatrixUseIntegral { @@ -4488,6 +4497,136 @@ struct IMetadata : public ISlangCastable }; #define SLANG_UUID_IMetadata IMetadata::getTypeGuid() +struct CooperativeMatrixType +{ + // A component type of `SLANG_SCALAR_TYPE_NONE` means this entry is not valid. + SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE; + SlangScope scope = SLANG_SCOPE_NONE; + + uint32_t rowCount = 0; + uint32_t columnCount = 0; + + SlangCooperativeMatrixUse use = SLANG_COOPERATIVE_MATRIX_USE_A; + + bool operator==(const CooperativeMatrixType& other) const + { + return componentType == other.componentType && scope == other.scope && + rowCount == other.rowCount && columnCount == other.columnCount && use == other.use; + } +}; + +struct CooperativeMatrixCombination +{ + // Number of rows of matrix A and the result. + uint32_t m = 0; + // Number of columns of matrix B and the result. + uint32_t n = 0; + // Shared inner dimension: columns of A and rows of B. 
+ uint32_t k = 0; + + SlangScalarType componentTypeA = SLANG_SCALAR_TYPE_NONE; + SlangScalarType componentTypeB = SLANG_SCALAR_TYPE_NONE; + SlangScalarType componentTypeC = SLANG_SCALAR_TYPE_NONE; + SlangScalarType componentTypeResult = SLANG_SCALAR_TYPE_NONE; + + SlangBool saturate = false; + SlangScope scope = SLANG_SCOPE_NONE; + + bool operator==(const CooperativeMatrixCombination& other) const + { + return m == other.m && n == other.n && k == other.k && + componentTypeA == other.componentTypeA && componentTypeB == other.componentTypeB && + componentTypeC == other.componentTypeC && + componentTypeResult == other.componentTypeResult && saturate == other.saturate && + scope == other.scope; + } +}; + +struct CooperativeVectorType +{ + SlangScalarType componentType = SLANG_SCALAR_TYPE_NONE; + + // Maximum element count for this component type across all cooperative + // vectors in the shader. Zero means the type appears only in training + // operations; non-zero does not preclude `usedForTrainingOp` being true. + uint32_t maxSize = 0; + + // Indicates this component type is used as the accumulation/storage type for cooperative + // vector training operations such as outer-product accumulation and reduce-sum + // accumulation. + SlangBool usedForTrainingOp = false; +}; + +struct CooperativeVectorCombination +{ + SlangScalarType inputType = SLANG_SCALAR_TYPE_NONE; + SlangScalarType inputInterpretation = SLANG_SCALAR_TYPE_NONE; + // Number of logical elements packed into each physical input element. + // For example, this is 4 when four int8 values are packed into one uint32 input element. + uint32_t inputPackingFactor = 1; + SlangScalarType matrixInterpretation = SLANG_SCALAR_TYPE_NONE; + // `NONE` means the operation has no bias operand/matrix. 
+ SlangScalarType biasInterpretation = SLANG_SCALAR_TYPE_NONE; + SlangScalarType resultType = SLANG_SCALAR_TYPE_NONE; + SlangBool transpose = false; + + bool operator==(const CooperativeVectorCombination& other) const + { + return inputType == other.inputType && inputInterpretation == other.inputInterpretation && + inputPackingFactor == other.inputPackingFactor && + matrixInterpretation == other.matrixInterpretation && + biasInterpretation == other.biasInterpretation && resultType == other.resultType && + transpose == other.transpose; + } +}; + +/** Cooperative matrix and vector metadata. + +This interface reports the cooperative matrix/vector type information that a compiled target uses, +including both the individual types and the type combinations required to execute certain +operations (such as matrix multiply-add). + +Applications can use this metadata to compare shader requirements against the capabilities exposed +by the target API/driver (for example Vulkan cooperative matrix/vector property queries, or +analogous APIs on other backends). + +Metadata is collected from the IR after target-specific lowering, so it only reflects cooperative +types that survive as native constructs in the final output. Targets that lower cooperative types +into ordinary arrays will report empty lists. + +Each list is exposed through a `get*Count()` / `get*ByIndex()` method pair: the count is the number of elements currently available and valid indices are in the range `[0, count)`. +`get*ByIndex()` returns `SLANG_E_INVALID_ARG` when the output pointer is null or the index is out of range. + +To obtain this interface, call `castAs()` on an `IMetadata*` with `ICooperativeTypesMetadata::getTypeGuid()`. 
+*/ +struct ICooperativeTypesMetadata : public ISlangCastable +{ + SLANG_COM_INTERFACE( + 0x64c4d536, + 0xd949, + 0x49c3, + {0x9f, 0xde, 0x3f, 0x0f, 0x9c, 0x6f, 0x01, 0x31}) + + virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() = 0; + virtual SlangResult SLANG_MCALL + getCooperativeMatrixTypeByIndex(SlangUInt index, CooperativeMatrixType* outType) = 0; + + virtual SlangUInt SLANG_MCALL getCooperativeMatrixCombinationCount() = 0; + virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex( + SlangUInt index, + CooperativeMatrixCombination* outCombination) = 0; + + virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() = 0; + virtual SlangResult SLANG_MCALL + getCooperativeVectorTypeByIndex(SlangUInt index, CooperativeVectorType* outType) = 0; + + virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount() = 0; + virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex( + SlangUInt index, + CooperativeVectorCombination* outCombination) = 0; +}; + #define SLANG_UUID_ICooperativeTypesMetadata ICooperativeTypesMetadata::getTypeGuid() + /** Compile result for storing and retrieving multiple output blobs. This is needed for features such as separate debug compilation which output both base and debug spirv. 
diff --git a/source/compiler-core/slang-artifact-associated-impl.cpp b/source/compiler-core/slang-artifact-associated-impl.cpp index 11bfec8bced..886bace4af4 100644 --- a/source/compiler-core/slang-artifact-associated-impl.cpp +++ b/source/compiler-core/slang-artifact-associated-impl.cpp @@ -287,6 +287,10 @@ void* ArtifactPostEmitMetadata::getInterface(const Guid& guid) { return static_cast(this); } + if (guid == slang::IMetadata::getTypeGuid()) + return static_cast(this); + if (guid == slang::ICooperativeTypesMetadata::getTypeGuid()) + return static_cast(this); return nullptr; } @@ -344,5 +348,81 @@ const char* ArtifactPostEmitMetadata::getDebugBuildIdentifier() return m_debugBuildIdentifier.getBuffer(); } +SlangUInt ArtifactPostEmitMetadata::getCooperativeMatrixTypeCount() +{ + return SlangUInt(m_cooperativeMatrixTypes.getCount()); +} + +SlangResult ArtifactPostEmitMetadata::getCooperativeMatrixTypeByIndex( + SlangUInt index, + slang::CooperativeMatrixType* outType) +{ + if (!outType) + return SLANG_E_INVALID_ARG; + + if (index >= SlangUInt(m_cooperativeMatrixTypes.getCount())) + return SLANG_E_INVALID_ARG; + + *outType = m_cooperativeMatrixTypes[Index(index)]; + return SLANG_OK; +} + +SlangUInt ArtifactPostEmitMetadata::getCooperativeMatrixCombinationCount() +{ + return SlangUInt(m_cooperativeMatrixCombinations.getCount()); +} + +SlangResult ArtifactPostEmitMetadata::getCooperativeMatrixCombinationByIndex( + SlangUInt index, + slang::CooperativeMatrixCombination* outCombination) +{ + if (!outCombination) + return SLANG_E_INVALID_ARG; + + if (index >= SlangUInt(m_cooperativeMatrixCombinations.getCount())) + return SLANG_E_INVALID_ARG; + + *outCombination = m_cooperativeMatrixCombinations[Index(index)]; + return SLANG_OK; +} + +SlangUInt ArtifactPostEmitMetadata::getCooperativeVectorTypeCount() +{ + return SlangUInt(m_cooperativeVectorTypes.getCount()); +} + +SlangResult ArtifactPostEmitMetadata::getCooperativeVectorTypeByIndex( + SlangUInt index, + 
slang::CooperativeVectorType* outType) +{ + if (!outType) + return SLANG_E_INVALID_ARG; + + if (index >= SlangUInt(m_cooperativeVectorTypes.getCount())) + return SLANG_E_INVALID_ARG; + + *outType = m_cooperativeVectorTypes[Index(index)]; + return SLANG_OK; +} + +SlangUInt ArtifactPostEmitMetadata::getCooperativeVectorCombinationCount() +{ + return SlangUInt(m_cooperativeVectorCombinations.getCount()); +} + +SlangResult ArtifactPostEmitMetadata::getCooperativeVectorCombinationByIndex( + SlangUInt index, + slang::CooperativeVectorCombination* outCombination) +{ + if (!outCombination) + return SLANG_E_INVALID_ARG; + + if (index >= SlangUInt(m_cooperativeVectorCombinations.getCount())) + return SLANG_E_INVALID_ARG; + + *outCombination = m_cooperativeVectorCombinations[Index(index)]; + return SLANG_OK; +} + } // namespace Slang diff --git a/source/compiler-core/slang-artifact-associated-impl.h b/source/compiler-core/slang-artifact-associated-impl.h index 9f60c51b6be..964d3962803 100644 --- a/source/compiler-core/slang-artifact-associated-impl.h +++ b/source/compiler-core/slang-artifact-associated-impl.h @@ -174,7 +174,9 @@ struct ShaderBindingRange } }; -class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitMetadata +class ArtifactPostEmitMetadata : public ComBaseObject, + public IArtifactPostEmitMetadata, + public slang::ICooperativeTypesMetadata { public: typedef ArtifactPostEmitMetadata ThisType; @@ -201,6 +203,26 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM SLANG_NO_THROW virtual const char* SLANG_MCALL getDebugBuildIdentifier() SLANG_OVERRIDE; + // ICooperativeTypesMetadata + SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeMatrixTypeCount() SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixTypeByIndex( + SlangUInt index, + slang::CooperativeMatrixType* outType) SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangUInt SLANG_MCALL 
getCooperativeMatrixCombinationCount() + SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeMatrixCombinationByIndex( + SlangUInt index, + slang::CooperativeMatrixCombination* outCombination) SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorTypeCount() SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorTypeByIndex( + SlangUInt index, + slang::CooperativeVectorType* outType) SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangUInt SLANG_MCALL getCooperativeVectorCombinationCount() + SLANG_OVERRIDE; + SLANG_NO_THROW virtual SlangResult SLANG_MCALL getCooperativeVectorCombinationByIndex( + SlangUInt index, + slang::CooperativeVectorCombination* outCombination) SLANG_OVERRIDE; + void* getInterface(const Guid& uuid); void* getObject(const Guid& uuid); @@ -211,6 +233,10 @@ class ArtifactPostEmitMetadata : public ComBaseObject, public IArtifactPostEmitM List m_usedBindings; List m_exportedFunctionMangledNames; + List m_cooperativeMatrixTypes; + List m_cooperativeMatrixCombinations; + List m_cooperativeVectorTypes; + List m_cooperativeVectorCombinations; String m_debugBuildIdentifier; }; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ee08cfbef05..bca195b7bf7 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -2262,6 +2262,10 @@ Result linkAndOptimizeIR( auto metadata = new ArtifactPostEmitMetadata; outLinkedIR.metadata = metadata; + // Runs after target-specific lowering so it only captures cooperative types that remain + // as native constructs visible to the driver (see ICooperativeTypesMetadata docs). 
+ SLANG_PASS(collectCooperativeMetadata, *metadata); + if (targetProgram->getOptionSet().getBoolOption(CompilerOptionName::EmbedDownstreamIR)) { SLANG_PASS(unexportNonEmbeddableIR, target); diff --git a/source/slang/slang-ir-metadata.cpp b/source/slang/slang-ir-metadata.cpp index a056d3683e5..239feac62a5 100644 --- a/source/slang/slang-ir-metadata.cpp +++ b/source/slang/slang-ir-metadata.cpp @@ -5,14 +5,69 @@ #include "slang-ir-insts.h" #include "slang-ir.h" -namespace Slang +#include + +// Define operator< for public cooperative type structs, used internally. +namespace slang +{ +bool operator<(const CooperativeMatrixType& a, const CooperativeMatrixType& b) { + if (a.componentType != b.componentType) + return a.componentType < b.componentType; + if (a.scope != b.scope) + return a.scope < b.scope; + if (a.rowCount != b.rowCount) + return a.rowCount < b.rowCount; + if (a.columnCount != b.columnCount) + return a.columnCount < b.columnCount; + return a.use < b.use; +} +bool operator<(const CooperativeMatrixCombination& a, const CooperativeMatrixCombination& b) +{ + if (a.m != b.m) + return a.m < b.m; + if (a.n != b.n) + return a.n < b.n; + if (a.k != b.k) + return a.k < b.k; + if (a.componentTypeA != b.componentTypeA) + return a.componentTypeA < b.componentTypeA; + if (a.componentTypeB != b.componentTypeB) + return a.componentTypeB < b.componentTypeB; + if (a.componentTypeC != b.componentTypeC) + return a.componentTypeC < b.componentTypeC; + if (a.componentTypeResult != b.componentTypeResult) + return a.componentTypeResult < b.componentTypeResult; + if (a.saturate != b.saturate) + return a.saturate < b.saturate; + return a.scope < b.scope; +} + +bool operator<(const CooperativeVectorCombination& a, const CooperativeVectorCombination& b) +{ + if (a.inputType != b.inputType) + return a.inputType < b.inputType; + if (a.inputInterpretation != b.inputInterpretation) + return a.inputInterpretation < b.inputInterpretation; + if (a.inputPackingFactor != b.inputPackingFactor) 
+ return a.inputPackingFactor < b.inputPackingFactor; + if (a.matrixInterpretation != b.matrixInterpretation) + return a.matrixInterpretation < b.matrixInterpretation; + if (a.biasInterpretation != b.biasInterpretation) + return a.biasInterpretation < b.biasInterpretation; + if (a.resultType != b.resultType) + return a.resultType < b.resultType; + return a.transpose < b.transpose; +} +} // namespace slang + +namespace Slang +{ // This file currently implements a pass that collects information about the shader parameters that // are referenced in the IR. It's named 'metadata' in order to support other potential code // analysis scenarios in the future. - // Inserts a single resource binding (which takes `count` slots, where 0 means unbounded) into the // list of resource ranges. static void _insertBinding( @@ -143,4 +198,407 @@ void collectMetadata(const IRModule* irModule, ArtifactPostEmitMetadata& outMeta } } +static SlangScalarType _getScalarTypeFromIRType(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_HalfType: + return SLANG_SCALAR_TYPE_FLOAT16; + case kIROp_FloatType: + return SLANG_SCALAR_TYPE_FLOAT32; + case kIROp_DoubleType: + return SLANG_SCALAR_TYPE_FLOAT64; + case kIROp_Int8Type: + return SLANG_SCALAR_TYPE_INT8; + case kIROp_Int16Type: + return SLANG_SCALAR_TYPE_INT16; + case kIROp_IntType: + return SLANG_SCALAR_TYPE_INT32; + case kIROp_Int64Type: + return SLANG_SCALAR_TYPE_INT64; + case kIROp_UInt8Type: + return SLANG_SCALAR_TYPE_UINT8; + case kIROp_UInt16Type: + return SLANG_SCALAR_TYPE_UINT16; + case kIROp_UIntType: + return SLANG_SCALAR_TYPE_UINT32; + case kIROp_UInt64Type: + return SLANG_SCALAR_TYPE_UINT64; + case kIROp_BFloat16Type: + return SLANG_SCALAR_TYPE_BFLOAT16; + case kIROp_FloatE4M3Type: + return SLANG_SCALAR_TYPE_FLOAT_E4M3; + case kIROp_FloatE5M2Type: + return SLANG_SCALAR_TYPE_FLOAT_E5M2; + default: + break; + } + return SLANG_SCALAR_TYPE_NONE; +} + +static bool _tryGetIntLiteralValue(IRInst* inst, IRIntegerValue& outValue) 
+{ + if (auto intLit = as(inst)) + { + outValue = intLit->getValue(); + return true; + } + return false; +} + +static SlangScope _getCooperativeMatrixScope(IRInst* scopeInst) +{ + IRIntegerValue val = 0; + if (!_tryGetIntLiteralValue(scopeInst, val) || !std::in_range(val)) + return SLANG_SCOPE_NONE; + switch (MemoryScope(val)) + { + case MemoryScope::Workgroup: + return SLANG_SCOPE_THREAD_GROUP; + case MemoryScope::Subgroup: + return SLANG_SCOPE_WAVE; + case MemoryScope::Invocation: + return SLANG_SCOPE_THREAD; + default: + return SLANG_SCOPE_NONE; + } +} + +static SlangScalarType _getCooperativeVectorInterpretation(IRInst* interpInst) +{ + IRIntegerValue val = 0; + if (!_tryGetIntLiteralValue(interpInst, val) || !std::in_range(val)) + return SLANG_SCALAR_TYPE_NONE; + + auto scalarType = SlangScalarType(val); + switch (scalarType) + { + case SLANG_SCALAR_TYPE_NONE: + case SLANG_SCALAR_TYPE_VOID: + case SLANG_SCALAR_TYPE_BOOL: + case SLANG_SCALAR_TYPE_INT32: + case SLANG_SCALAR_TYPE_UINT32: + case SLANG_SCALAR_TYPE_INT64: + case SLANG_SCALAR_TYPE_UINT64: + case SLANG_SCALAR_TYPE_FLOAT16: + case SLANG_SCALAR_TYPE_FLOAT32: + case SLANG_SCALAR_TYPE_FLOAT64: + case SLANG_SCALAR_TYPE_INT8: + case SLANG_SCALAR_TYPE_UINT8: + case SLANG_SCALAR_TYPE_INT16: + case SLANG_SCALAR_TYPE_UINT16: + case SLANG_SCALAR_TYPE_INTPTR: + case SLANG_SCALAR_TYPE_UINTPTR: + case SLANG_SCALAR_TYPE_BFLOAT16: + case SLANG_SCALAR_TYPE_FLOAT_E4M3: + case SLANG_SCALAR_TYPE_FLOAT_E5M2: + return scalarType; + } + // No default case above: want compiler warning if enum grows. 
+ return SLANG_SCALAR_TYPE_NONE; +} + +static slang::CooperativeMatrixType _getCooperativeMatrixType(IRInst* inst) +{ + slang::CooperativeMatrixType type = {}; + + auto coopMatType = as(inst); + if (!coopMatType) + return {}; + + type.componentType = _getScalarTypeFromIRType(coopMatType->getElementType()); + if (type.componentType == SLANG_SCALAR_TYPE_NONE) + return {}; + + if (!as(coopMatType->getRowCount()) || !as(coopMatType->getColumnCount()) || + !as(coopMatType->getMatrixUse())) + return {}; + + type.scope = _getCooperativeMatrixScope(coopMatType->getScope()); + if (type.scope == SLANG_SCOPE_NONE) + return {}; + + IRIntegerValue rowValue = getIntVal(coopMatType->getRowCount()); + IRIntegerValue columnValue = getIntVal(coopMatType->getColumnCount()); + if (!std::in_range(rowValue) || !std::in_range(columnValue)) + return {}; + + type.rowCount = uint32_t(rowValue); + type.columnCount = uint32_t(columnValue); + + IRIntegerValue matrixUseValue = getIntVal(coopMatType->getMatrixUse()); + if (!std::in_range(matrixUseValue)) + return {}; + + auto matrixUse = SlangCooperativeMatrixUse(matrixUseValue); + switch (matrixUse) + { + case SLANG_COOPERATIVE_MATRIX_USE_A: + case SLANG_COOPERATIVE_MATRIX_USE_B: + case SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR: + type.use = matrixUse; + return type; + } + // No default case above: want compiler warning if enum grows. 
+ return {}; +} + +static slang::CooperativeVectorType _getCooperativeVectorType(IRInst* inst) +{ + slang::CooperativeVectorType type = {}; + + auto coopVecType = as(inst); + if (!coopVecType) + return {}; + + type.componentType = _getScalarTypeFromIRType(coopVecType->getElementType()); + if (type.componentType == SLANG_SCALAR_TYPE_NONE) + return {}; + + if (!as(coopVecType->getElementCount())) + return {}; + + IRIntegerValue maxSizeValue = getIntVal(coopVecType->getElementCount()); + if (!std::in_range(maxSizeValue)) + return {}; + + type.maxSize = uint32_t(maxSizeValue); + return type; +} + +template +static Index lowerBound(const List& list, const T2& value, Compare compare) +{ + Index imin = 0; + Index imax = list.getCount(); + while (imax > imin) + { + Index imid = imin + ((imax - imin) >> 1); + if (compare(list[imid], value) < 0) + imin = imid + 1; + else + imax = imid; + } + return imin; +} + +template +static Index lowerBound(const List& list, const T2& value) +{ + return lowerBound( + list, + value, + [](const T& currentValue, const T2& searchValue) -> int + { + if (currentValue < searchValue) + return -1; + if (currentValue == searchValue) + return 0; + return +1; + }); +} + +template +static void _insertSortedUnique(List& list, const T& value) +{ + Index insertIndex = lowerBound(list, value); + if (insertIndex >= list.getCount() || !(list[insertIndex] == value)) + list.insert(insertIndex, value); +} + +static void _insertOrUpdateCooperativeVectorType( + List& list, + SlangScalarType componentType, + uint32_t maxSize, + bool usedForTrainingOp) +{ + if (componentType == SLANG_SCALAR_TYPE_NONE) + return; + + slang::CooperativeVectorType key = {}; + key.componentType = componentType; + + // Compare by component type only: observations that differ in maxSize or + // usedForTrainingOp are merged into a single entry per component type. 
+ auto compareByComponentType = [](const slang::CooperativeVectorType& a, + const slang::CooperativeVectorType& b) -> int + { + if (a.componentType < b.componentType) + return -1; + if (a.componentType > b.componentType) + return 1; + return 0; + }; + + Index insertIndex = lowerBound(list, key, compareByComponentType); + if (insertIndex < list.getCount() && compareByComponentType(list[insertIndex], key) == 0) + { + auto& existing = list[insertIndex]; + if (existing.maxSize < maxSize) + existing.maxSize = maxSize; + if (usedForTrainingOp) + existing.usedForTrainingOp = true; + } + else + { + key.maxSize = maxSize; + key.usedForTrainingOp = usedForTrainingOp; + list.insert(insertIndex, key); + } +} + +// Operand literal types are guaranteed by validateCooperativeOperations which runs +// before this pass, so cast<> is used instead of as<> + null-check. +static void collectMetadataFromCooperativeVectorCombination( + IRInst* inst, + ArtifactPostEmitMetadata& outMetadata) +{ + auto matMulAdd = as(inst); + if (!matMulAdd) + return; + + slang::CooperativeVectorCombination combination = {}; + + combination.inputType = + _getCooperativeVectorType(matMulAdd->getInput()->getDataType()).componentType; + + combination.inputInterpretation = + _getCooperativeVectorInterpretation(matMulAdd->getInputInterpretation()); + + IRIntegerValue packingFactorValue = + cast(matMulAdd->getInputInterpretationPackingFactor())->getValue(); + if (!std::in_range(packingFactorValue)) + return; + combination.inputPackingFactor = uint32_t(packingFactorValue); + + combination.matrixInterpretation = + _getCooperativeVectorInterpretation(matMulAdd->getMatrixInterpretation()); + + combination.biasInterpretation = SLANG_SCALAR_TYPE_NONE; + if (auto biasInterpretation = matMulAdd->getBiasInterpretation()) + combination.biasInterpretation = _getCooperativeVectorInterpretation(biasInterpretation); + + combination.resultType = _getCooperativeVectorType(matMulAdd->getDataType()).componentType; + + 
combination.transpose = cast(matMulAdd->getTranspose())->getValue(); + + if (!combination.inputType || !combination.inputInterpretation || + !combination.matrixInterpretation || !combination.resultType) + { + return; + } + + _insertSortedUnique(outMetadata.m_cooperativeVectorCombinations, combination); +} + +static void collectMetadataFromCooperativeVectorTrainingUsage( + IRInst* inst, + ArtifactPostEmitMetadata& outMetadata) +{ + SlangScalarType trainingType = SLANG_SCALAR_TYPE_NONE; + uint32_t maxSize = 0; + + if (auto outerProduct = as(inst)) + { + // For outer-product accumulation, the training-specific type we want to surface is the + // matrix accumulation/storage interpretation. That type does not correspond to a + // cooperative vector width in the public Vulkan API, so keep `maxSize` at zero instead of + // inheriting the operand vector sizes. + trainingType = _getCooperativeVectorInterpretation(outerProduct->getMatrixInterpretation()); + } + else if (auto reduceSum = as(inst)) + { + auto vectorType = _getCooperativeVectorType(reduceSum->getValue()->getDataType()); + trainingType = vectorType.componentType; + maxSize = vectorType.maxSize; + } + else + { + return; + } + + _insertOrUpdateCooperativeVectorType( + outMetadata.m_cooperativeVectorTypes, + trainingType, + maxSize, + true); +} + +// Operand literal types are guaranteed by validateCooperativeOperations which runs +// before this pass, so cast<> is used instead of as<> + null-check. 
+static void collectMetadataFromCooperativeMatrixCombination( + IRInst* inst, + ArtifactPostEmitMetadata& outMetadata) +{ + auto matMulAdd = as(inst); + if (!matMulAdd) + return; + + slang::CooperativeMatrixType typeA = + _getCooperativeMatrixType(matMulAdd->getMatA()->getDataType()); + slang::CooperativeMatrixType typeB = + _getCooperativeMatrixType(matMulAdd->getMatB()->getDataType()); + slang::CooperativeMatrixType typeC = + _getCooperativeMatrixType(matMulAdd->getMatC()->getDataType()); + slang::CooperativeMatrixType typeResult = _getCooperativeMatrixType(inst->getDataType()); + + if (!typeA.componentType || !typeB.componentType || !typeC.componentType || + !typeResult.componentType) + { + return; + } + + slang::CooperativeMatrixCombination combination = {}; + combination.m = typeA.rowCount; + combination.n = typeB.columnCount; + combination.k = typeA.columnCount; + combination.componentTypeA = typeA.componentType; + combination.componentTypeB = typeB.componentType; + combination.componentTypeC = typeC.componentType; + combination.componentTypeResult = typeResult.componentType; + combination.saturate = cast(matMulAdd->getSaturatingAccumulation())->getValue(); + // All four matrices are required by validateCoopMatMulAdd to have the same scope. 
+ combination.scope = typeA.scope; + + _insertSortedUnique(outMetadata.m_cooperativeMatrixCombinations, combination); +} + +static void collectCooperativeMetadataFromInst(IRInst* inst, ArtifactPostEmitMetadata& outMetadata) +{ + auto resolved = getResolvedInstForDecorations(inst); + + if (as(resolved)) + { + auto matrixType = _getCooperativeMatrixType(resolved); + if (matrixType.componentType) + _insertSortedUnique(outMetadata.m_cooperativeMatrixTypes, matrixType); + } + else if (as(resolved)) + { + auto vectorType = _getCooperativeVectorType(resolved); + _insertOrUpdateCooperativeVectorType( + outMetadata.m_cooperativeVectorTypes, + vectorType.componentType, + vectorType.maxSize, + false); + } + + collectMetadataFromCooperativeMatrixCombination(inst, outMetadata); + collectMetadataFromCooperativeVectorCombination(inst, outMetadata); + collectMetadataFromCooperativeVectorTrainingUsage(inst, outMetadata); +} + +void collectCooperativeMetadata(const IRModule* irModule, ArtifactPostEmitMetadata& outMetadata) +{ + List insts; + findAllInstsBreadthFirst(irModule->getModuleInst(), insts); + + for (auto inst : insts) + { + if (as(inst)) + continue; + collectCooperativeMetadataFromInst(inst, outMetadata); + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-metadata.h b/source/slang/slang-ir-metadata.h index 964d7bb1417..5849f909234 100644 --- a/source/slang/slang-ir-metadata.h +++ b/source/slang/slang-ir-metadata.h @@ -7,6 +7,8 @@ namespace Slang class ArtifactPostEmitMetadata; struct IRModule; +void collectCooperativeMetadata(const IRModule* irModule, ArtifactPostEmitMetadata& outMetadata); + void collectMetadata(const IRModule* irModule, ArtifactPostEmitMetadata& outMetadata); } // namespace Slang diff --git a/tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp b/tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp new file mode 100644 index 00000000000..2a413fa9b49 --- /dev/null +++ 
b/tools/slang-unit-test/unit-test-cooperative-type-metadata.cpp @@ -0,0 +1,861 @@ +// unit-test-cooperative-type-metadata.cpp + +#include "core/slang-list.h" +#include "core/slang-string.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +using namespace Slang; + +struct CooperativeMetadataTargetDesc +{ + const char* name; + SlangCompileTarget target; + const char* profileName; + const char* const* capabilityNames; + int capabilityCount; +}; + +static ComPtr _compileAndGetCooperativeMetadata( + const char* source, + const char* moduleNameBase, + const CooperativeMetadataTargetDesc& target) +{ + ComPtr globalSession; + SlangResult res = slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + res = globalSession->checkCompileTargetSupport(target.target); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + List capabilityOptions; + for (int i = 0; i < target.capabilityCount; ++i) + { + auto cap = globalSession->findCapability(target.capabilityNames[i]); + SLANG_CHECK(cap != SLANG_CAPABILITY_UNKNOWN); + if (cap == SLANG_CAPABILITY_UNKNOWN) + return nullptr; + + slang::CompilerOptionEntry entry = {}; + entry.name = slang::CompilerOptionName::Capability; + entry.value.kind = slang::CompilerOptionValueKind::Int; + entry.value.intValue0 = int32_t(cap); + capabilityOptions.add(entry); + } + + slang::TargetDesc targetDesc = {}; + targetDesc.format = target.target; + if (target.profileName) + { + targetDesc.profile = globalSession->findProfile(target.profileName); + SLANG_CHECK(targetDesc.profile != SLANG_PROFILE_UNKNOWN); + if (targetDesc.profile == SLANG_PROFILE_UNKNOWN) + return nullptr; + } + targetDesc.compilerOptionEntries = capabilityOptions.getBuffer(); + targetDesc.compilerOptionEntryCount = capabilityOptions.getCount(); + + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + 
sessionDesc.targets = &targetDesc; + + ComPtr session; + res = globalSession->createSession(sessionDesc, session.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + String moduleName; + moduleName.append(moduleNameBase); + moduleName.append("_"); + moduleName.append(target.name); + + String fileName; + fileName.append(moduleName); + fileName.append(".slang"); + + ComPtr diagnostics; + auto module = session->loadModuleFromSourceString( + moduleName.getBuffer(), + fileName.getBuffer(), + source, + diagnostics.writeRef()); + SLANG_CHECK(module != nullptr); + if (!module) + return nullptr; + + ComPtr entryPoint; + res = module->findAndCheckEntryPoint( + "computeMain", + SLANG_STAGE_COMPUTE, + entryPoint.writeRef(), + diagnostics.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + ComPtr compositeProgram; + slang::IComponentType* components[] = {module, entryPoint}; + res = session->createCompositeComponentType( + components, + 2, + compositeProgram.writeRef(), + diagnostics.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + ComPtr linkedProgram; + res = compositeProgram->link(linkedProgram.writeRef(), diagnostics.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + ComPtr metadata; + res = linkedProgram->getTargetMetadata(0, metadata.writeRef(), diagnostics.writeRef()); + SLANG_CHECK(res == SLANG_OK); + if (res != SLANG_OK) + return nullptr; + + auto ptr = static_cast( + metadata->castAs(slang::ICooperativeTypesMetadata::getTypeGuid())); + SLANG_CHECK(ptr != nullptr); + if (!ptr) + return nullptr; + + return ComPtr(ptr); +} + +template +static void _checkListContainsEachExpectedExactlyOnce( + SlangUInt actualCount, + F getByIndex, + const T* expected, + int expectedCount) +{ + List foundCounts; + foundCounts.setCount(expectedCount); + for (int i = 0; i < expectedCount; ++i) + foundCounts[i] = 0; + + 
SLANG_CHECK(actualCount == SlangUInt(expectedCount)); + + for (SlangUInt i = 0; i < actualCount; ++i) + { + T actual = getByIndex(i); + for (int j = 0; j < expectedCount; ++j) + { + if (actual == expected[j]) + foundCounts[j]++; + } + } + + for (int j = 0; j < expectedCount; ++j) + SLANG_CHECK(foundCounts[j] == 1); +} + +static void _validateMatrixMetadata( + slang::ICooperativeTypesMetadata* metadata, + const slang::CooperativeMatrixType* expectedTypes, + int expectedTypeCount, + const slang::CooperativeMatrixCombination* expectedCombinations, + int expectedCombinationCount) +{ + auto matrixTypeCount = metadata->getCooperativeMatrixTypeCount(); + _checkListContainsEachExpectedExactlyOnce( + matrixTypeCount, + [&](SlangUInt i) + { + slang::CooperativeMatrixType type = {}; + SLANG_CHECK(metadata->getCooperativeMatrixTypeByIndex(i, &type) == SLANG_OK); + SLANG_CHECK(type.componentType != SLANG_SCALAR_TYPE_NONE); + return type; + }, + expectedTypes, + expectedTypeCount); + + auto combinationCount = metadata->getCooperativeMatrixCombinationCount(); + _checkListContainsEachExpectedExactlyOnce( + combinationCount, + [&](SlangUInt i) + { + slang::CooperativeMatrixCombination combination = {}; + SLANG_CHECK( + metadata->getCooperativeMatrixCombinationByIndex(i, &combination) == SLANG_OK); + SLANG_CHECK(combination.componentTypeA != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.componentTypeB != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.componentTypeC != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.componentTypeResult != SLANG_SCALAR_TYPE_NONE); + return combination; + }, + expectedCombinations, + expectedCombinationCount); + + SLANG_CHECK(metadata->getCooperativeMatrixTypeByIndex(0, nullptr) == SLANG_E_INVALID_ARG); + + slang::CooperativeMatrixType invalidType = {}; + SLANG_CHECK( + metadata->getCooperativeMatrixTypeByIndex(matrixTypeCount, &invalidType) == + SLANG_E_INVALID_ARG); + + SLANG_CHECK( + metadata->getCooperativeMatrixCombinationByIndex(0, 
nullptr) == SLANG_E_INVALID_ARG); + + slang::CooperativeMatrixCombination invalidCombination = {}; + SLANG_CHECK( + metadata->getCooperativeMatrixCombinationByIndex(combinationCount, &invalidCombination) == + SLANG_E_INVALID_ARG); +} + +static void _validateVectorTypeMetadata( + slang::ICooperativeTypesMetadata* metadata, + const slang::CooperativeVectorType* expectedTypes, + int expectedTypeCount) +{ + auto typeCount = metadata->getCooperativeVectorTypeCount(); + + List foundCounts; + foundCounts.setCount(expectedTypeCount); + for (int i = 0; i < expectedTypeCount; ++i) + foundCounts[i] = 0; + + SLANG_CHECK(typeCount == SlangUInt(expectedTypeCount)); + + for (SlangUInt i = 0; i < typeCount; ++i) + { + slang::CooperativeVectorType type = {}; + SLANG_CHECK(metadata->getCooperativeVectorTypeByIndex(i, &type) == SLANG_OK); + SLANG_CHECK(type.componentType != SLANG_SCALAR_TYPE_NONE); + + int matchedIndex = -1; + for (int j = 0; j < expectedTypeCount; ++j) + { + if (type.componentType == expectedTypes[j].componentType) + { + matchedIndex = j; + break; + } + } + + SLANG_CHECK(matchedIndex != -1); + if (matchedIndex != -1) + { + SLANG_CHECK(type.maxSize == expectedTypes[matchedIndex].maxSize); + SLANG_CHECK(type.usedForTrainingOp == expectedTypes[matchedIndex].usedForTrainingOp); + foundCounts[matchedIndex]++; + } + } + + for (int j = 0; j < expectedTypeCount; ++j) + SLANG_CHECK(foundCounts[j] == 1); + + SLANG_CHECK(metadata->getCooperativeVectorTypeByIndex(0, nullptr) == SLANG_E_INVALID_ARG); + + slang::CooperativeVectorType invalidType = {}; + SLANG_CHECK( + metadata->getCooperativeVectorTypeByIndex(typeCount, &invalidType) == SLANG_E_INVALID_ARG); +} + +static void _validateVectorCombinationMetadata( + slang::ICooperativeTypesMetadata* metadata, + const slang::CooperativeVectorCombination* expectedCombinations, + int expectedCombinationCount) +{ + auto combinationCount = metadata->getCooperativeVectorCombinationCount(); + + _checkListContainsEachExpectedExactlyOnce( + 
combinationCount, + [&](SlangUInt i) + { + slang::CooperativeVectorCombination combination = {}; + SLANG_CHECK( + metadata->getCooperativeVectorCombinationByIndex(i, &combination) == SLANG_OK); + SLANG_CHECK(combination.inputType != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.inputInterpretation != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.matrixInterpretation != SLANG_SCALAR_TYPE_NONE); + SLANG_CHECK(combination.resultType != SLANG_SCALAR_TYPE_NONE); + return combination; + }, + expectedCombinations, + expectedCombinationCount); + + SLANG_CHECK( + metadata->getCooperativeVectorCombinationByIndex(0, nullptr) == SLANG_E_INVALID_ARG); + + slang::CooperativeVectorCombination invalidCombination = {}; + SLANG_CHECK( + metadata->getCooperativeVectorCombinationByIndex(combinationCount, &invalidCombination) == + SLANG_E_INVALID_ARG); +} + +static const char* const kSpirvCoopMatCaps[] = {"spvCooperativeMatrixKHR"}; +static const char* const kSpirvCoopVecCaps[] = {"spvCooperativeVectorNV"}; +static const char* const kSpirvCoopVecTrainingCaps[] = { + "spvCooperativeVectorNV", + "spvCooperativeVectorTrainingNV"}; +static const char* const kCudaOptixCoopVecCaps[] = {"optix_coopvec"}; + +static const CooperativeMetadataTargetDesc kCooperativeMatrixSubgroupTargets[] = { + {"spirv", SLANG_SPIRV, "spirv_1_6", kSpirvCoopMatCaps, 1}, + {"cuda", SLANG_CUDA_SOURCE, nullptr, nullptr, 0}, +}; + +static const CooperativeMetadataTargetDesc kCooperativeMatrixWorkgroupTargets[] = { + {"spirv", SLANG_SPIRV, "spirv_1_6", kSpirvCoopMatCaps, 1}, +}; + +static const CooperativeMetadataTargetDesc kCooperativeVectorTargets[] = { + {"spirv", SLANG_SPIRV, "spirv_1_6", kSpirvCoopVecCaps, 1}, + {"hlsl", SLANG_HLSL, "sm_6_9", nullptr, 0}, + {"cuda_optix", SLANG_CUDA_SOURCE, nullptr, kCudaOptixCoopVecCaps, 1}, +}; + +static const CooperativeMetadataTargetDesc kCooperativeVectorTrainingTargets[] = { + {"spirv", SLANG_SPIRV, "spirv_1_6", kSpirvCoopVecTrainingCaps, 2}, + {"hlsl", SLANG_HLSL, 
"sm_6_9", nullptr, 0}, + {"cuda_optix", SLANG_CUDA_SOURCE, nullptr, kCudaOptixCoopVecCaps, 1}, +}; + +// Cooperative vectors are lowered before metadata collection for plain CUDA targets. +// See the lowerCooperativeVectors() dispatch in source/slang/slang-emit.cpp. +static const CooperativeMetadataTargetDesc kCooperativeVectorLoweringTargets[] = { + {"cuda", SLANG_CUDA_SOURCE, nullptr, nullptr, 0}, +}; + +SLANG_UNIT_TEST(cooperativeMatrixSubgroupTypeMetadata) +{ + const char* subgroupSource = R"( +using namespace linalg; + +RWStructuredBuffer outputBuffer; +RWStructuredBuffer outputBufferInt; + +[shader("compute")] +[numthreads(32,1,1)] +void computeMain() +{ + let d = coopMatMulAdd( + CoopMat(1.0), + CoopMat(2.0), + CoopMat(3.0) + ); + d.Store(outputBuffer, 0, 16); + + let dInt = coopMatMulAdd( + CoopMat(1), + CoopMat(2), + CoopMat(3) + ); + dInt.Store(outputBufferInt, 0, 16); + + let dIntSat = coopMatMulAdd( + CoopMat(1), + CoopMat(2), + CoopMat(3) + ); + dIntSat.Store(outputBufferInt, 16, 16); +} +)"; + + static const slang::CooperativeMatrixType expectedSubgroupTypes[] = { + {.componentType = SLANG_SCALAR_TYPE_FLOAT16, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_A}, + {.componentType = SLANG_SCALAR_TYPE_FLOAT16, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_B}, + {.componentType = SLANG_SCALAR_TYPE_FLOAT32, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR}, + {.componentType = SLANG_SCALAR_TYPE_INT8, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_A}, + {.componentType = SLANG_SCALAR_TYPE_INT8, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_B}, + {.componentType = SLANG_SCALAR_TYPE_INT32, + .scope = SLANG_SCOPE_WAVE, + .rowCount = 16, + .columnCount = 16, + 
.use = SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR}, + }; + + static const slang::CooperativeMatrixCombination expectedSubgroupCombinations[] = { + {.m = 16, + .n = 16, + .k = 16, + .componentTypeA = SLANG_SCALAR_TYPE_FLOAT16, + .componentTypeB = SLANG_SCALAR_TYPE_FLOAT16, + .componentTypeC = SLANG_SCALAR_TYPE_FLOAT32, + .componentTypeResult = SLANG_SCALAR_TYPE_FLOAT32, + .saturate = false, + .scope = SLANG_SCOPE_WAVE}, + {.m = 16, + .n = 16, + .k = 16, + .componentTypeA = SLANG_SCALAR_TYPE_INT8, + .componentTypeB = SLANG_SCALAR_TYPE_INT8, + .componentTypeC = SLANG_SCALAR_TYPE_INT32, + .componentTypeResult = SLANG_SCALAR_TYPE_INT32, + .saturate = false, + .scope = SLANG_SCOPE_WAVE}, + {.m = 16, + .n = 16, + .k = 16, + .componentTypeA = SLANG_SCALAR_TYPE_INT8, + .componentTypeB = SLANG_SCALAR_TYPE_INT8, + .componentTypeC = SLANG_SCALAR_TYPE_INT32, + .componentTypeResult = SLANG_SCALAR_TYPE_INT32, + .saturate = true, + .scope = SLANG_SCOPE_WAVE}, + }; + + for (const auto& target : kCooperativeMatrixSubgroupTargets) + { + auto metadata = _compileAndGetCooperativeMetadata( + subgroupSource, + "coopMatrixSubgroupTypeModule", + target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + _validateMatrixMetadata( + metadata, + expectedSubgroupTypes, + int(SLANG_COUNT_OF(expectedSubgroupTypes)), + expectedSubgroupCombinations, + int(SLANG_COUNT_OF(expectedSubgroupCombinations))); + + SLANG_CHECK(metadata->getCooperativeVectorTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeMatrixWorkgroupTypeMetadata) +{ + const char* workgroupSource = R"( +using namespace linalg; + +RWStructuredBuffer outputBufferWorkgroup; + +[shader("compute")] +[numthreads(32,1,1)] +void computeMain() +{ + let dWorkgroup = coopMatMulAdd( + CoopMat(1.0), + CoopMat(2.0), + CoopMat(3.0) + ); + dWorkgroup.Store(outputBufferWorkgroup, 0, 16); +} +)"; + + static const slang::CooperativeMatrixType expectedWorkgroupTypes[] 
= { + {.componentType = SLANG_SCALAR_TYPE_FLOAT16, + .scope = SLANG_SCOPE_THREAD_GROUP, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_A}, + {.componentType = SLANG_SCALAR_TYPE_FLOAT16, + .scope = SLANG_SCOPE_THREAD_GROUP, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_B}, + {.componentType = SLANG_SCALAR_TYPE_FLOAT32, + .scope = SLANG_SCOPE_THREAD_GROUP, + .rowCount = 16, + .columnCount = 16, + .use = SLANG_COOPERATIVE_MATRIX_USE_ACCUMULATOR}, + }; + + static const slang::CooperativeMatrixCombination expectedWorkgroupCombinations[] = { + {.m = 16, + .n = 16, + .k = 16, + .componentTypeA = SLANG_SCALAR_TYPE_FLOAT16, + .componentTypeB = SLANG_SCALAR_TYPE_FLOAT16, + .componentTypeC = SLANG_SCALAR_TYPE_FLOAT32, + .componentTypeResult = SLANG_SCALAR_TYPE_FLOAT32, + .saturate = false, + .scope = SLANG_SCOPE_THREAD_GROUP}, + }; + + for (const auto& target : kCooperativeMatrixWorkgroupTargets) + { + auto metadata = _compileAndGetCooperativeMetadata( + workgroupSource, + "coopMatrixWorkgroupTypeModule", + target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + _validateMatrixMetadata( + metadata, + expectedWorkgroupTypes, + int(SLANG_COUNT_OF(expectedWorkgroupTypes)), + expectedWorkgroupCombinations, + int(SLANG_COUNT_OF(expectedWorkgroupCombinations))); + + SLANG_CHECK(metadata->getCooperativeVectorTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeVectorTypeMetadata) +{ + const char* source = R"( +using namespace linalg; + +RWStructuredBuffer outputBuffer; +ByteAddressBuffer input; +RWByteAddressBuffer matrix; +RWByteAddressBuffer bias; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + let vec4 = coopVecLoad<4, int8_t>(input); + let vec8 = coopVecLoad<8, int8_t>(input); + let packedVec = coopVecLoad<1, uint>(input); + + constexpr const CoopVecComponentType signedInt8 = 
CoopVecComponentType::SignedInt8; + constexpr const CoopVecComponentType signedInt32 = CoopVecComponentType::SignedInt32; + constexpr const CoopVecMatrixLayout rowMajor = CoopVecMatrixLayout::RowMajor; + constexpr const bool noTranspose = false; + + let resultA = coopVecMatMulAdd( + vec4, + signedInt8, + matrix, + 0, + signedInt8, + bias, + 0, + signedInt32, + rowMajor, + noTranspose, + 4); + + let resultB = coopVecMatMulAdd( + vec8, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + bias, + 0, + CoopVecComponentType::SignedInt32, + CoopVecMatrixLayout::RowMajor, + true, + 8); + + let resultC = coopVecMatMul( + vec4, + CoopVecComponentType::SignedInt8, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + let resultPacked = coopVecMatMulPacked( + packedVec, + CoopVecComponentType::SignedInt8Packed, + 4, + matrix, + 0, + CoopVecComponentType::SignedInt8, + CoopVecMatrixLayout::RowMajor, + false, + 4); + + for (int i = 0; i < resultA.getCount(); ++i) + { + outputBuffer[i] = resultA[i] + resultB[i] + resultC[i] + resultPacked[i]; + } +} +)"; + + static const slang::CooperativeVectorType expectedTypes[] = { + {.componentType = SLANG_SCALAR_TYPE_INT8, .maxSize = 8, .usedForTrainingOp = false}, + {.componentType = SLANG_SCALAR_TYPE_INT32, .maxSize = 4, .usedForTrainingOp = false}, + {.componentType = SLANG_SCALAR_TYPE_UINT32, .maxSize = 1, .usedForTrainingOp = false}, + }; + + static const slang::CooperativeVectorCombination expectedCombinations[] = { + {.inputType = SLANG_SCALAR_TYPE_INT8, + .inputInterpretation = SLANG_SCALAR_TYPE_INT8, + .inputPackingFactor = 1, + .matrixInterpretation = SLANG_SCALAR_TYPE_INT8, + .biasInterpretation = SLANG_SCALAR_TYPE_INT32, + .resultType = SLANG_SCALAR_TYPE_INT32, + .transpose = false}, + {.inputType = SLANG_SCALAR_TYPE_INT8, + .inputInterpretation = SLANG_SCALAR_TYPE_INT8, + .inputPackingFactor = 1, + .matrixInterpretation = 
SLANG_SCALAR_TYPE_INT8, + .biasInterpretation = SLANG_SCALAR_TYPE_INT32, + .resultType = SLANG_SCALAR_TYPE_INT32, + .transpose = true}, + {.inputType = SLANG_SCALAR_TYPE_INT8, + .inputInterpretation = SLANG_SCALAR_TYPE_INT8, + .inputPackingFactor = 1, + .matrixInterpretation = SLANG_SCALAR_TYPE_INT8, + .biasInterpretation = SLANG_SCALAR_TYPE_NONE, + .resultType = SLANG_SCALAR_TYPE_INT32, + .transpose = false}, + {.inputType = SLANG_SCALAR_TYPE_UINT32, + .inputInterpretation = SLANG_SCALAR_TYPE_INT8, + .inputPackingFactor = 4, + .matrixInterpretation = SLANG_SCALAR_TYPE_INT8, + .biasInterpretation = SLANG_SCALAR_TYPE_NONE, + .resultType = SLANG_SCALAR_TYPE_INT32, + .transpose = false}, + }; + + for (const auto& target : kCooperativeVectorTargets) + { + auto metadata = _compileAndGetCooperativeMetadata(source, "coopVectorTypeModule", target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + _validateVectorTypeMetadata(metadata, expectedTypes, int(SLANG_COUNT_OF(expectedTypes))); + + _validateVectorCombinationMetadata( + metadata, + expectedCombinations, + int(SLANG_COUNT_OF(expectedCombinations))); + + SLANG_CHECK(metadata->getCooperativeMatrixTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeVectorTrainingMetadata) +{ + const char* source = R"( +using namespace linalg; + +ByteAddressBuffer input; +RWByteAddressBuffer matrix; +RWByteAddressBuffer output; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + let v = coopVecLoad<4, float>(input); + constexpr const CoopVecComponentType float16Type = CoopVecComponentType::Float16; + + coopVecOuterProductAccumulate( + v, + v, + matrix, + 0, + 8, + CoopVecMatrixLayout::TrainingOptimal, + float16Type); + + coopVecReduceSumAccumulate(v, output, 0); +} +)"; + + static const slang::CooperativeVectorType expectedTypes[] = { + {.componentType = SLANG_SCALAR_TYPE_FLOAT16, .maxSize = 0, .usedForTrainingOp = true}, + 
{.componentType = SLANG_SCALAR_TYPE_FLOAT32, .maxSize = 4, .usedForTrainingOp = true}, + }; + + for (const auto& target : kCooperativeVectorTrainingTargets) + { + auto metadata = + _compileAndGetCooperativeMetadata(source, "coopVectorTrainingTypeModule", target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + _validateVectorTypeMetadata(metadata, expectedTypes, int(SLANG_COUNT_OF(expectedTypes))); + + SLANG_CHECK(metadata->getCooperativeVectorCombinationCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeVectorMixedTrainingAndNonTrainingMetadata) +{ + const char* source = R"( +using namespace linalg; + +ByteAddressBuffer input; +RWByteAddressBuffer matrix; +RWByteAddressBuffer bias; +RWByteAddressBuffer reduceOutput; +RWStructuredBuffer resultBuffer; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + let vec8 = coopVecLoad<8, float>(input); + let vec4 = coopVecLoad<4, float>(input); + + let result = coopVecMatMulAdd( + vec8, + CoopVecComponentType::Float32, + matrix, + 0, + CoopVecComponentType::Float32, + bias, + 0, + CoopVecComponentType::Float32, + CoopVecMatrixLayout::RowMajor, + false, + 8); + + coopVecReduceSumAccumulate(vec4, reduceOutput, 0); + + for (int i = 0; i < result.getCount(); ++i) + { + resultBuffer[i] = result[i]; + } +} +)"; + + static const slang::CooperativeVectorType expectedTypes[] = { + {.componentType = SLANG_SCALAR_TYPE_FLOAT32, .maxSize = 8, .usedForTrainingOp = true}, + }; + + static const slang::CooperativeVectorCombination expectedCombinations[] = { + {.inputType = SLANG_SCALAR_TYPE_FLOAT32, + .inputInterpretation = SLANG_SCALAR_TYPE_FLOAT32, + .inputPackingFactor = 1, + .matrixInterpretation = SLANG_SCALAR_TYPE_FLOAT32, + .biasInterpretation = SLANG_SCALAR_TYPE_FLOAT32, + .resultType = SLANG_SCALAR_TYPE_FLOAT32, + .transpose = false}, + }; + + for (const auto& target 
: kCooperativeVectorTrainingTargets) + { + auto metadata = _compileAndGetCooperativeMetadata( + source, + "coopVectorMixedTrainingNonTrainingTypeModule", + target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + _validateVectorTypeMetadata(metadata, expectedTypes, int(SLANG_COUNT_OF(expectedTypes))); + _validateVectorCombinationMetadata( + metadata, + expectedCombinations, + int(SLANG_COUNT_OF(expectedCombinations))); + + SLANG_CHECK(metadata->getCooperativeMatrixTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeMetadataLoweredVectorTarget) +{ + const char* source = R"( +using namespace linalg; + +ByteAddressBuffer input; +RWByteAddressBuffer matrix; +RWByteAddressBuffer bias; +RWStructuredBuffer resultBuffer; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + let vec8 = coopVecLoad<8, float>(input); + + let result = coopVecMatMulAdd( + vec8, + CoopVecComponentType::Float32, + matrix, + 0, + CoopVecComponentType::Float32, + bias, + 0, + CoopVecComponentType::Float32, + CoopVecMatrixLayout::RowMajor, + false, + 8); + + for (int i = 0; i < result.getCount(); ++i) + { + resultBuffer[i] = result[i]; + } +} +)"; + + for (const auto& target : kCooperativeVectorLoweringTargets) + { + auto metadata = + _compileAndGetCooperativeMetadata(source, "coopLoweredVectorTargetModule", target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + SLANG_CHECK(metadata->getCooperativeMatrixTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixCombinationCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorCombinationCount() == 0); + } +} + +SLANG_UNIT_TEST(cooperativeMetadataEmptyShader) +{ + const char* source = R"( +RWStructuredBuffer outputBuffer; + +[shader("compute")] +[numthreads(1,1,1)] +void computeMain() +{ + outputBuffer[0] = 1.0f; +} +)"; + + static const 
CooperativeMetadataTargetDesc targets[] = { + {"spirv", SLANG_SPIRV, "spirv_1_6", kSpirvCoopMatCaps, 1}, + {"hlsl", SLANG_HLSL, "sm_6_9", nullptr, 0}, + }; + + for (const auto& target : targets) + { + auto metadata = _compileAndGetCooperativeMetadata(source, "coopEmptyModule", target); + SLANG_CHECK(metadata != nullptr); + if (!metadata) + return; + + SLANG_CHECK(metadata->getCooperativeMatrixTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeMatrixCombinationCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorTypeCount() == 0); + SLANG_CHECK(metadata->getCooperativeVectorCombinationCount() == 0); + } +}