diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 420dd3efcde..45a3427084a 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -825,7 +825,6 @@ bool CUDASourceEmitter::tryEmitInstStmtImpl(IRInst* inst) } case kIROp_CoopVecMatMulAdd: { - auto coopVecMatMulAdd = cast(inst); if (!isOptixCoopVec) { getSink()->diagnose(Diagnostics::UnsupportedTargetIntrinsic{ diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 0bc58b704f3..595409284c0 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -8313,6 +8313,41 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return result; } + // Emit an operand that satisfies cooperative vector SPIRV opcodes, which require a pointer + // to an array type (not a pointer to a struct). Returns the SpvInst to use as that operand. + // + // - Buffer resource case (ByteAddressBuffer / StructuredBuffer): after SPIRV legalization + // the global param has type ptr-to-struct{runtimeArray}. Emit an OpAccessChain with + // index 0 to pierce through the wrapper struct and expose the runtime array. + // + // - Direct pointer case: the pointee is not a struct (the pointer already targets the + // unsized array), so the value is returned unchanged via ensureInst(). + SpvInst* emitBufferPtrAsArrayPtr(SpvInstParent* parent, IRInst* bufferVal) + { + IRBuilder builder(bufferVal); + auto addressSpace = + isSpirv14OrLater() ? AddressSpace::StorageBuffer : AddressSpace::Uniform; + IRPtrTypeBase* bufPtrType = cast(bufferVal->getDataType()); + // If the pointee is not a struct, the pointer already targets an array directly + // (e.g. a pointer to an unsized array) — use it without modification. + IRStructType* bufType = as(bufPtrType->getValueType()); + if (!bufType) + return ensureInst(bufferVal); + // The struct's first (and only) field is the runtime array of elements. 
+ IRArrayTypeBase* arrayType = + cast(bufType->getFields().getFirst()->getFieldType()); + return emitOpAccessChain( + parent, + nullptr, + builder.getPtrType( + arrayType, + AccessQualifier::ReadWrite, + addressSpace, + bufPtrType->getDataLayout()), + bufferVal, + makeArray(emitIntConstant(0, builder.getIntType()))); + } + SpvInst* emitGetBufferPtr(SpvInstParent* parent, IRInst* inst) { IRBuilder builder(inst); @@ -9194,12 +9229,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex emitMappedCoopVecComponentTypeOperand( coopVecMatMulAdd->getInputInterpretation(), coopVecMatMulAdd->getInputInterpretationPackingFactor()); - emitOperand(coopVecMatMulAdd->getMatrixPtr()); + emitOperand(emitBufferPtrAsArrayPtr(parent, coopVecMatMulAdd->getMatrixPtr())); emitOperand(coopVecMatMulAdd->getMatrixOffset()); emitMappedCoopVecComponentTypeOperand(coopVecMatMulAdd->getMatrixInterpretation()); if (hasBias) { - emitOperand(coopVecMatMulAdd->getBiasPtr()); + emitOperand(emitBufferPtrAsArrayPtr(parent, coopVecMatMulAdd->getBiasPtr())); emitOperand(coopVecMatMulAdd->getBiasOffset()); emitMappedCoopVecComponentTypeOperand( coopVecMatMulAdd->getBiasInterpretation()); @@ -9229,7 +9264,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvOpCooperativeVectorOuterProductAccumulateNV, [&]() { - emitOperand(outerProduct->getMatrixPtr()); + emitOperand(emitBufferPtrAsArrayPtr(parent, outerProduct->getMatrixPtr())); emitOperand(outerProduct->getMatrixOffset()); emitOperand(outerProduct->getA()); emitOperand(outerProduct->getB()); @@ -9246,13 +9281,16 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex auto reduceSum = cast(inst); - return emitInst( + return emitInstCustomOperandFunc( parent, inst, SpvOpCooperativeVectorReduceSumAccumulateNV, - reduceSum->getBufferPtr(), - reduceSum->getOffset(), - reduceSum->getValue()); + [&]() + { + emitOperand(emitBufferPtrAsArrayPtr(parent, 
reduceSum->getBufferPtr())); + emitOperand(reduceSum->getOffset()); + emitOperand(reduceSum->getValue()); + }); } SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems)