diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst index 28e919fdf516a..8f7ac71f8026b 100644 --- a/llvm/docs/SPIRVUsage.rst +++ b/llvm/docs/SPIRVUsage.rst @@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na - Introduces two new storage classes that are subclasses of the CrossWorkgroup storage class that provides additional information that can enable optimization. * - ``SPV_INTEL_variable_length_array`` - Allows to allocate local arrays whose number of elements is unknown at compile time. + * - ``SPV_INTEL_joint_matrix`` + - Adds a few matrix capabilities on top of the SPV_KHR_cooperative_matrix extension, such as matrix prefetch, get element coordinate and checked load/store/construct instructions, tensor float 32 and bfloat type interpretations for the multiply-add instruction. * - ``SPV_KHR_bit_instructions`` - Enables bit instructions to be used by SPIR-V modules without requiring the Shader capability. * - ``SPV_KHR_expect_assume`` diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp index 0f9a2a69e0739..67bf459615249 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp @@ -137,8 +137,12 @@ getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension) { CapabilityList Capabilities; while (Entry && - Entry->Category == SPIRV::OperandCategory::CapabilityOperand && - Entry->ReqExtension == Extension) { + Entry->Category == SPIRV::OperandCategory::CapabilityOperand) { + // Some capabilities' codes might not go in order. 
+ if (Entry->ReqExtension != Extension) { + ++Entry; + continue; + } Capabilities.push_back( static_cast(Entry->Value)); ++Entry; diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h index 44625793e9413..2437fbb820a36 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h @@ -207,6 +207,16 @@ namespace Opcode { #include "SPIRVGenTables.inc" } // namespace Opcode +namespace CooperativeMatrixLayout { +#define GET_CooperativeMatrixLayout_DECL +#include "SPIRVGenTables.inc" +} // namespace CooperativeMatrixLayout + +namespace CooperativeMatrixOperands { +#define GET_CooperativeMatrixOperands_DECL +#include "SPIRVGenTables.inc" +} // namespace CooperativeMatrixOperands + struct ExtendedBuiltin { StringRef Name; InstructionSet::InstructionSet Set; diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp index ff8759755e517..2ee0c79b8f7c1 100644 --- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp +++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp @@ -211,6 +211,34 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address, // are part of the variable value. 
printOpConstantVarOps(MI, NumFixedOps - 1, OS); break; + case SPIRV::OpCooperativeMatrixMulAddKHR: { + const unsigned NumOps = MI->getNumOperands(); + if (NumFixedOps == NumOps) + break; + + OS << ' '; + const unsigned MulAddOp = MI->getOperand(FirstVariableIndex).getImm(); + if (MulAddOp == 0) { + printSymbolicOperand< + OperandCategory::CooperativeMatrixOperandsOperand>( + MI, FirstVariableIndex, OS); + } else { + std::string Buffer; + for (unsigned Mask = 0x1; + Mask <= SPIRV::CooperativeMatrixOperands:: + MatrixResultBFloat16ComponentsINTEL; + Mask <<= 1) { + if (MulAddOp & Mask) { + if (!Buffer.empty()) + Buffer += '|'; + Buffer += getSymbolicOperandMnemonic( + OperandCategory::CooperativeMatrixOperandsOperand, Mask); + } + } + OS << Buffer; + } + break; + } default: printRemainingVariableOps(MI, NumFixedOps, OS); break; diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 45a49674d4ca2..9b6c2a849edce 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1969,15 +1969,49 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call, const SPIRV::DemangledBuiltin *Builtin = Call->Builtin; unsigned Opcode = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode; - bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR; + bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR && + Opcode != SPIRV::OpCooperativeMatrixStoreCheckedINTEL && + Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL; unsigned ArgSz = Call->Arguments.size(); unsigned LiteralIdx = 0; - if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3) - LiteralIdx = 3; - else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4) - LiteralIdx = 4; + switch (Opcode) { + // Memory operand is optional and is literal. + case SPIRV::OpCooperativeMatrixLoadKHR: + LiteralIdx = ArgSz > 3 ? 3 : 0; + break; + case SPIRV::OpCooperativeMatrixStoreKHR: + LiteralIdx = ArgSz > 4 ? 
4 : 0; + break; + case SPIRV::OpCooperativeMatrixLoadCheckedINTEL: + LiteralIdx = ArgSz > 7 ? 7 : 0; + break; + case SPIRV::OpCooperativeMatrixStoreCheckedINTEL: + LiteralIdx = ArgSz > 8 ? 8 : 0; + break; + // Cooperative Matrix Operands operand is optional and is literal. + case SPIRV::OpCooperativeMatrixMulAddKHR: + LiteralIdx = ArgSz > 3 ? 3 : 0; + break; + }; + SmallVector ImmArgs; MachineRegisterInfo *MRI = MIRBuilder.getMRI(); + if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) { + const uint32_t CacheLevel = getConstFromIntrinsic(Call->Arguments[3], MRI); + auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL) + .addUse(Call->Arguments[0]) // pointer + .addUse(Call->Arguments[1]) // rows + .addUse(Call->Arguments[2]) // columns + .addImm(CacheLevel) // cache level + .addUse(Call->Arguments[4]); // memory layout + if (ArgSz > 5) + MIB.addUse(Call->Arguments[5]); // stride + if (ArgSz > 6) { + const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI); + MIB.addImm(MemOp); // memory operand + } + return true; + } if (LiteralIdx > 0) ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI)); Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType); diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index dc2da4a3a5647..e29013d28aafe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -695,6 +695,13 @@ defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, C defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 4, OpCooperativeMatrixMulAddKHR>; defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>; +// Cooperative Matrix Intel builtin records: +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixPrefetchINTEL", OpenCL_std, CoopMatr, 5, 7, OpCooperativeMatrixPrefetchINTEL>; +defm : 
DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadCheckedINTEL", OpenCL_std, CoopMatr, 6, 8, OpCooperativeMatrixLoadCheckedINTEL>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreCheckedINTEL", OpenCL_std, CoopMatr, 7, 9, OpCooperativeMatrixStoreCheckedINTEL>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixConstructCheckedINTEL", OpenCL_std, CoopMatr, 5, 5, OpCooperativeMatrixConstructCheckedINTEL>; +defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixGetElementCoordINTEL", OpenCL_std, CoopMatr, 2, 2, OpCooperativeMatrixGetElementCoordINTEL>; + //===----------------------------------------------------------------------===// // Class defining a work/sub group builtin that should be translated into a // SPIR-V instruction using the defined properties. diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index e78fc5ce18707..fb05c1fdbd1e3 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -51,6 +51,8 @@ static const std::map> SPIRV::Extension::Extension::SPV_INTEL_subgroups}, {"SPV_INTEL_media_block_io", SPIRV::Extension::Extension::SPV_INTEL_media_block_io}, + {"SPV_INTEL_joint_matrix", + SPIRV::Extension::Extension::SPV_INTEL_joint_matrix}, {"SPV_KHR_uniform_group_instructions", SPIRV::Extension::Extension::SPV_KHR_uniform_group_instructions}, {"SPV_KHR_no_integer_wrap_decoration", diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 53f1b644a9498..d95803fea56a5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -895,6 +895,23 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res), def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type), "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">; +// SPV_INTEL_joint_matrix +def OpCooperativeMatrixLoadCheckedINTEL: Op<6193, (outs ID:$res), 
+ (ins TYPE:$resType, ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$memory_layout, ID:$height, ID:$width, variable_ops), + "$res = OpCooperativeMatrixLoadCheckedINTEL $resType $pointer $xOffset $yOffset $memory_layout $height $width">; +def OpCooperativeMatrixStoreCheckedINTEL: Op<6194, (outs), + (ins ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$objectToStore, ID:$memory_layout, ID:$height, ID:$width, variable_ops), + "OpCooperativeMatrixStoreCheckedINTEL $pointer $xOffset $yOffset $objectToStore $memory_layout $height $width">; +def OpCooperativeMatrixConstructCheckedINTEL: Op<6195, (outs ID:$res), + (ins TYPE:$resType, ID:$xOffset, ID:$yOffset, ID:$height, ID:$width, ID:$value), + "$res = OpCooperativeMatrixConstructCheckedINTEL $resType $xOffset $yOffset $height $width $value">; +def OpCooperativeMatrixGetElementCoordINTEL: Op<6440, (outs ID:$res), + (ins TYPE:$resType, ID:$matrix, ID:$index), + "$res = OpCooperativeMatrixGetElementCoordINTEL $resType $matrix $index">; +def OpCooperativeMatrixPrefetchINTEL: Op<6449, (outs), + (ins ID:$pointer, ID:$rows, ID:$columns, i32imm:$cacheLevel, ID:$memory_layout, variable_ops), + "OpCooperativeMatrixPrefetchINTEL $pointer $rows $columns $cacheLevel $memory_layout">; + // SPV_EXT_arithmetic_fence def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target), "$res = OpArithmeticFenceEXT $type $target">; diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 2054081476315..4ee71c703f90b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1437,6 +1437,138 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL); } break; + case SPIRV::OpCooperativeMatrixMulAddKHR: { + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) + report_fatal_error("Cooperative matrix instructions require the " + "following SPIR-V extension: " + 
"SPV_KHR_cooperative_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); + Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); + constexpr unsigned MulAddMaxSize = 6; + if (MI.getNumOperands() != MulAddMaxSize) + break; + const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm(); + if (CoopOperands & + SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) + report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation " + "require the following SPIR-V extension: " + "SPV_INTEL_joint_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + Reqs.addCapability( + SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL); + } + if (CoopOperands & SPIRV::CooperativeMatrixOperands:: + MatrixAAndBBFloat16ComponentsINTEL || + CoopOperands & + SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL || + CoopOperands & SPIRV::CooperativeMatrixOperands:: + MatrixResultBFloat16ComponentsINTEL) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) + report_fatal_error("***BF16ComponentsINTEL type interpretations " + "require the following SPIR-V extension: " + "SPV_INTEL_joint_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + Reqs.addCapability( + SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL); + } + break; + } + case SPIRV::OpCooperativeMatrixLoadKHR: + case SPIRV::OpCooperativeMatrixStoreKHR: + case SPIRV::OpCooperativeMatrixLoadCheckedINTEL: + case SPIRV::OpCooperativeMatrixStoreCheckedINTEL: + case SPIRV::OpCooperativeMatrixPrefetchINTEL: { + if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix)) + report_fatal_error("Cooperative matrix instructions require the " + "following SPIR-V extension: " + "SPV_KHR_cooperative_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix); + 
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR); + + // Check Layout operand in case if it's not a standard one and add the + // appropriate capability. + std::unordered_map LayoutToInstMap = { + {SPIRV::OpCooperativeMatrixLoadKHR, 3}, + {SPIRV::OpCooperativeMatrixStoreKHR, 2}, + {SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5}, + {SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4}, + {SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}}; + + const auto OpCode = MI.getOpcode(); + const unsigned LayoutNum = LayoutToInstMap[OpCode]; + Register RegLayout = MI.getOperand(LayoutNum).getReg(); + const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); + MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout); + if (MILayout->getOpcode() == SPIRV::OpConstantI) { + const unsigned LayoutVal = MILayout->getOperand(2).getImm(); + if (LayoutVal == + static_cast(SPIRV::CooperativeMatrixLayout::PackedINTEL)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) + report_fatal_error("PackedINTEL layout require the following SPIR-V " + "extension: SPV_INTEL_joint_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL); + } + } + + // Nothing to do. 
+ if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR || + OpCode == SPIRV::OpCooperativeMatrixStoreKHR) + break; + + std::string InstName; + switch (OpCode) { + case SPIRV::OpCooperativeMatrixPrefetchINTEL: + InstName = "OpCooperativeMatrixPrefetchINTEL"; + break; + case SPIRV::OpCooperativeMatrixLoadCheckedINTEL: + InstName = "OpCooperativeMatrixLoadCheckedINTEL"; + break; + case SPIRV::OpCooperativeMatrixStoreCheckedINTEL: + InstName = "OpCooperativeMatrixStoreCheckedINTEL"; + break; + } + + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) { + const std::string ErrorMsg = + InstName + " instruction requires the " + "following SPIR-V extension: SPV_INTEL_joint_matrix"; + report_fatal_error(ErrorMsg.c_str(), false); + } + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) { + Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL); + break; + } + Reqs.addCapability( + SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL); + break; + } + case SPIRV::OpCooperativeMatrixConstructCheckedINTEL: + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) + report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL " + "instructions require the following SPIR-V extension: " + "SPV_INTEL_joint_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + Reqs.addCapability( + SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL); + break; + case SPIRV::OpCooperativeMatrixGetElementCoordINTEL: + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) + report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the " + "following SPIR-V extension: SPV_INTEL_joint_matrix", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix); + Reqs.addCapability( + SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL); + break; case SPIRV::OpKill: { Reqs.addCapability(SPIRV::Capability::Shader); } 
break; diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index a3a88acdd6c6a..745d1e1aec67a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -170,6 +170,8 @@ def GroupOperationOperand : OperandCategory; def KernelEnqueueFlagsOperand : OperandCategory; def KernelProfilingInfoOperand : OperandCategory; def OpcodeOperand : OperandCategory; +def CooperativeMatrixLayoutOperand : OperandCategory; +def CooperativeMatrixOperandsOperand : OperandCategory; //===----------------------------------------------------------------------===// // Multiclass used to define Extesions enum values and at the same time @@ -305,6 +307,7 @@ defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>; defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>; defm SPV_EXT_arithmetic_fence : ExtensionOperand<112>; defm SPV_EXT_optnone : ExtensionOperand<113>; +defm SPV_INTEL_joint_matrix : ExtensionOperand<114>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -492,6 +495,12 @@ defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_control defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>; defm ArithmeticFenceEXT : CapabilityOperand<6144, 0, 0, [SPV_EXT_arithmetic_fence], []>; defm SplitBarrierINTEL : CapabilityOperand<6141, 0, 0, [SPV_INTEL_split_barrier], []>; +defm CooperativeMatrixCheckedInstructionsINTEL : CapabilityOperand<6192, 0, 0, [SPV_INTEL_joint_matrix], []>; +defm CooperativeMatrixPrefetchINTEL : CapabilityOperand<6411, 0, 0, [SPV_INTEL_joint_matrix], []>; +defm PackedCooperativeMatrixINTEL : CapabilityOperand<6434, 0, 0, [SPV_INTEL_joint_matrix], []>; +defm CooperativeMatrixInvocationInstructionsINTEL : CapabilityOperand<6435, 0, 0, [SPV_INTEL_joint_matrix], []>; +defm 
CooperativeMatrixTF32ComponentTypeINTEL : CapabilityOperand<6436, 0, 0, [SPV_INTEL_joint_matrix], []>; +defm CooperativeMatrixBFloat16ComponentTypeINTEL : CapabilityOperand<6437, 0, 0, [SPV_INTEL_joint_matrix], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time @@ -1649,3 +1658,62 @@ defm GenericCastToPtr : OpcodeOperand<122>; defm Bitcast : OpcodeOperand<124>; defm ConvertPtrToU : OpcodeOperand<117>; defm ConvertUToPtr : OpcodeOperand<120>; + +//===----------------------------------------------------------------------===// +// Multiclass used to define Cooperative Matrix Layout enum values and at the +// same time SymbolicOperand entries extensions and capabilities. +//===----------------------------------------------------------------------===// + +def CooperativeMatrixLayout : GenericEnum, Operand { + let FilterClass = "CooperativeMatrixLayout"; + let NameField = "Name"; + let ValueField = "Value"; +} + +class CooperativeMatrixLayout value> { + string Name = name; + bits<32> Value = value; +} + +multiclass CooperativeMatrixLayoutOperand value, list reqExtensions, list reqCapabilities> { + def : CooperativeMatrixLayout; + defm : SymbolicOperandWithRequirements; +} + +defm RowMajorKHR : CooperativeMatrixLayoutOperand<0x0, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm ColumnMajorKHR : CooperativeMatrixLayoutOperand<0x1, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm PackedINTEL : CooperativeMatrixLayoutOperand<0x2, [SPV_INTEL_joint_matrix], [PackedCooperativeMatrixINTEL]>; + +//===----------------------------------------------------------------------===// +// Multiclass used to define Cooperative Matrix Operands enum values and at the +// same time SymbolicOperand entries with string mnemonics, extensions and +// capabilities. 
+//===----------------------------------------------------------------------===// + +def CooperativeMatrixOperands : GenericEnum, Operand { + let FilterClass = "CooperativeMatrixOperands"; + let NameField = "Name"; + let ValueField = "Value"; + let PrintMethod = !strconcat("printSymbolicOperand"); +} + +class CooperativeMatrixOperands value> { + string Name = name; + bits<32> Value = value; +} + +multiclass CooperativeMatrixOperandsOperand value, list reqExtensions, list reqCapabilities> { + def : CooperativeMatrixOperands; + defm : SymbolicOperandWithRequirements; +} + +defm NoneKHR : CooperativeMatrixOperandsOperand<0x0, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm MatrixASignedComponentsKHR : CooperativeMatrixOperandsOperand<0x1, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm MatrixBSignedComponentsKHR : CooperativeMatrixOperandsOperand<0x2, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm MatrixCSignedComponentsKHR : CooperativeMatrixOperandsOperand<0x4, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm MatrixResultSignedComponentsKHR : CooperativeMatrixOperandsOperand<0x8, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm SaturatingAccumulationKHR : CooperativeMatrixOperandsOperand<0x10, [SPV_KHR_cooperative_matrix], [CooperativeMatrixKHR]>; +defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SPV_INTEL_joint_matrix], [CooperativeMatrixTF32ComponentTypeINTEL]>; +defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>; +defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>; +defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>; diff --git 
a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_bf16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_bf16.ll new file mode 100644 index 0000000000000..c0b23a324992e --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_bf16.ll @@ -0,0 +1,32 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: ***BF16ComponentsINTEL type interpretations require the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixKHR +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-DAG: CooperativeMatrixBFloat16ComponentTypeINTEL + +; CHECK: OpCooperativeMatrixMulAddKHR %[[#]] %[[#]] %[[#]] %[[#]] MatrixAAndBBFloat16ComponentsINTEL + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) { +entry: + %matrixC = tail call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float 0.0) + %matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 0, i64 noundef %_arg_K, i32 noundef 1) + %matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accB, i32 noundef 1, i64 noundef %_arg_K) + %res = tail call spir_func noundef 
target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %matrixC, i32 noundef 64) + tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(1) noundef %_arg_accC, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %res, i32 noundef 0, i64 noundef %_arg_N, i32 noundef 1) + ret void +} + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i16, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", i16, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef, i32 noundef) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", float, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64) + +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2), i32, i64, i32) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll new file mode 100644 index 0000000000000..a4b2c4be5084b --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll @@ -0,0 +1,46 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown 
--spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: OpCooperativeMatrixConstructCheckedINTEL instructions require the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixKHR +; CHECK-DAG: Capability CooperativeMatrixCheckedInstructionsINTEL +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-DAG: %[[#Int8Ty:]] = OpTypeInt 8 0 +; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#Const12:]] = OpConstant %[[#Int32Ty]] 12 +; CHECK-DAG: %[[#Const48:]] = OpConstant %[[#Int32Ty]] 48 +; CHECK-DAG: %[[#Const0:]] = OpConstant %[[#Int32Ty]] 0 +; CHECK-DAG: %[[#Const3:]] = OpConstant %[[#Int32Ty]] 3 +; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2 +; CHECK-DAG: %[[#Const1:]] = OpConstant %[[#Int32Ty]] 1 +; CHECK-DAG: %[[#MatTy1:]] = OpTypeCooperativeMatrixKHR %[[#Int32Ty]] %[[#Const3]] %[[#Const12]] %[[#Const12]] %[[#Const2]] +; CHECK-DAG: %[[#MatTy2:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const3]] %[[#Const12]] %[[#Const48]] %[[#Const0]] +; CHECK-DAG: %[[#MatTy3:]] = OpTypeCooperativeMatrixKHR %[[#Int8Ty]] %[[#Const2]] %[[#Const48]] %[[#Const12]] %[[#Const1]] +; CHECK: OpCooperativeMatrixConstructCheckedINTEL %[[#MatTy1]] +; CHECK: OpCooperativeMatrixLoadCheckedINTEL %[[#MatTy2]] +; CHECK: OpCooperativeMatrixLoadCheckedINTEL %[[#MatTy3]] +; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]] +; CHECK: OpCooperativeMatrixStoreCheckedINTEL + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) { +entry: + %matrixC = tail 
call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i32 noundef 12, i32 noundef 12, i32 noundef %_arg_Initvalue) + %matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 12, i32 noundef 48, i64 noundef %_arg_K, i32 noundef 1) + %matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(1) noundef %_arg_accB, i32 noundef 0, i32 noundef 0, i32 noundef 1, i32 noundef 48, i32 noundef 12, i64 noundef %_arg_K) + %res = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %matrixC, i32 noundef 12) + tail call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(1) noundef %_arg_accC, i32 noundef 0, i32 noundef 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %res, i32 noundef 0, i32 noundef 12, i32 noundef 12, i64 noundef %_arg_N, i32 noundef 1) + ret void +} + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) + +declare dso_local spir_func noundef 
target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef) + +declare dso_local spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_get_coord.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_get_coord.ll new file mode 100644 index 0000000000000..e7ec5e186d16a --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_get_coord.ll @@ -0,0 +1,24 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: OpCooperativeMatrixGetElementCoordINTEL requires the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixKHR +; CHECK-DAG: Capability CooperativeMatrixInvocationInstructionsINTEL +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-DAG: %[[#MatrixTy:]] = OpTypeCooperativeMatrixKHR + +; CHECK: %[[#Matrix:]] = OpCompositeConstruct %[[#MatrixTy]] +; CHECK: %[[#]] = 
OpCooperativeMatrixGetElementCoordINTEL %[[#]] %[[#Matrix]] %[[#]] + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(i32 noundef %_idx, i32 noundef %_arg_Initvalue) { +entry: + %matrixC = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float 0.0) + %coord = tail call spir_func <2 x i32> @_Z45__spirv_CooperativeMatrixGetElementCoordINTEL(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %matrixC, i32 noundef %_idx) + ret void +} +declare dso_local spir_func <2 x i32> @_Z45__spirv_CooperativeMatrixGetElementCoordINTEL(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_packed.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_packed.ll new file mode 100644 index 0000000000000..8b336df1c2c20 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_packed.ll @@ -0,0 +1,58 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: PackedINTEL layout require the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixKHR +; CHECK-DAG: Capability PackedCooperativeMatrixINTEL +; CHECK-DAG: Capability CooperativeMatrixCheckedInstructionsINTEL +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#Const2:]] = OpConstant %[[#Int32Ty]] 2 + +; CHECK: 
OpCooperativeMatrixLoadKHR %[[#]] %[[#]] %[[#Const2]] +; CHECK: OpCooperativeMatrixLoadKHR %[[#]] %[[#]] %[[#Const2]] +; CHECK: OpCooperativeMatrixStoreKHR %[[#]] %[[#]] %[[#Const2]] +; CHECK: OpCooperativeMatrixLoadCheckedINTEL %[[#]] %[[#]] %[[#]] %[[#]] %[[#Const2]] +; CHECK: OpCooperativeMatrixLoadCheckedINTEL %[[#]] %[[#]] %[[#]] %[[#]] %[[#Const2]] +; CHECK: OpCooperativeMatrixStoreCheckedINTEL %[[#]] %[[#]] %[[#]] %[[#]] %[[#Const2]] + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) { +entry: + %matrixC = tail call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 %_arg_Initvalue) + %matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 2, i64 noundef %_arg_K, i32 noundef 1) + %matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accB, i32 noundef 2, i64 noundef %_arg_K) + %res = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %matrixC, i32 noundef 12) + tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(1) noundef %_arg_accC, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %res, i32 noundef 2, i64 noundef %_arg_N, i32 noundef 1) + ret void +} + +define weak_odr dso_local spir_kernel void @_ZTSZZ23matrix_multiply_checked(ptr 
addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) { +entry: + %matrixC = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i32 noundef 12, i32 noundef 12, i32 noundef %_arg_Initvalue) + %matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 0, i32 noundef 0, i32 noundef 2, i32 noundef 12, i32 noundef 48, i64 noundef %_arg_K, i32 noundef 1) + %matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(1) noundef %_arg_accB, i32 noundef 0, i32 noundef 0, i32 noundef 2, i32 noundef 48, i32 noundef 12, i64 noundef %_arg_K) + %res = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %matrixC, i32 noundef 12) + tail call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(1) noundef %_arg_accC, i32 noundef 0, i32 noundef 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %res, i32 noundef 2, i32 noundef 12, i32 noundef 12, i64 noundef %_arg_N, i32 noundef 1) + ret void +} + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 
noundef, i32 noundef, i32 noundef, i32 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef) + +declare dso_local spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64) + +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll new file mode 100644 index 0000000000000..8573e09284403 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll 
@@ -0,0 +1,27 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: OpCooperativeMatrixPrefetchINTEL instruction requires the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixPrefetchINTEL +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: OpExtension "SPV_INTEL_joint_matrix" + +; CHECK-DAG: OpCooperativeMatrixPrefetchINTEL %[[#]] %[[#]] %[[#]] 0 %[[#]] +; CHECK-DAG: OpCooperativeMatrixPrefetchINTEL %[[#]] %[[#]] %[[#]] 0 %[[#]] %[[#]] +; CHECK-DAG: OpCooperativeMatrixPrefetchINTEL %[[#]] %[[#]] %[[#]] 0 %[[#]] %[[#]] 1 + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_K) { +entry: + tail call spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiii(ptr addrspace(1) noundef %_arg_accA, i32 noundef 12, i32 noundef 48, i32 noundef 0, i32 noundef 0) + tail call spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiiil(ptr addrspace(1) noundef %_arg_accB, i32 noundef 12, i32 noundef 48, i32 noundef 0, i32 noundef 0, i64 noundef %_arg_K) + tail call spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiiili(ptr addrspace(1) noundef %_arg_accC, i32 noundef 12, i32 noundef 48, i32 noundef 0, i32 noundef 0, i64 noundef %_arg_K, i32 1) + ret void +} + +declare dso_local spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiii(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) + +declare dso_local spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiiil(ptr addrspace(1) noundef, i32 
noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef) + +declare dso_local spir_func void @_Z38__spirv_CooperativeMatrixPrefetchINTELPU3AS4ciiiili(ptr addrspace(1) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_tf32.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_tf32.ll new file mode 100644 index 0000000000000..d3ada306d2c2f --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_joint_matrix/cooperative_matrix_tf32.ll @@ -0,0 +1,32 @@ +; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: LLVM ERROR: MatrixAAndBTF32ComponentsINTEL type interpretation require the following SPIR-V extension: SPV_INTEL_joint_matrix + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_cooperative_matrix,+SPV_INTEL_joint_matrix %s -o - | FileCheck %s + +; CHECK-DAG: Capability CooperativeMatrixKHR +; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix" +; CHECK-DAG: Extension "SPV_INTEL_joint_matrix" +; CHECK-DAG: OpCapability CooperativeMatrixTF32ComponentTypeINTEL + +; CHECK: OpCooperativeMatrixMulAddKHR %[[#]] %[[#]] %[[#]] %[[#]] MatrixAAndBTF32ComponentsINTEL + +define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr addrspace(1) noundef align 1 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, float noundef %_arg_Initvalue) { +entry: + %matrixC = tail call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float %_arg_Initvalue) + %matrixA = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accA, i32 noundef 0, 
i64 noundef %_arg_K, i32 noundef 1) + %matrixB = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 2, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(1) noundef %_arg_accB, i32 noundef 1, i64 noundef %_arg_K) + %res = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", float, 3, 12, 48, 0) noundef %matrixA, target("spirv.CooperativeMatrixKHR", float, 2, 48, 12, 1) noundef %matrixB, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %matrixC, i32 noundef 32) + tail call spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(1) noundef %_arg_accC, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef %res, i32 noundef 0, i64 noundef %_arg_N, i32 noundef 1) + ret void +} + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(float noundef) + +declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", float, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", float, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 2) noundef, i32 noundef) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 48, 0) @_Z32__spirv_CooperativeMatrixLoadKHR_1(ptr addrspace(3), i32, i64, i32) + +declare dso_local spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 48, 12, 1) @_Z32__spirv_CooperativeMatrixLoadKHR_2(ptr addrspace(3), i32, i64) + +declare dso_local spir_func void @_Z33__spirv_CooperativeMatrixStoreKHR(ptr addrspace(3), target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), i32, i64, i32) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll 
b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll index 1c41c7331cda8..e290c1eaeabad 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll @@ -21,7 +21,7 @@ ; CHECK: %[[#Load1:]] = OpCooperativeMatrixLoadKHR %[[#MatTy2]] ; CHECK: OpCooperativeMatrixLengthKHR %[[#Int32Ty]] %[[#MatTy2:]] ; CHECK: OpCooperativeMatrixLoadKHR %[[#MatTy3]] -; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]] +; CHECK: OpCooperativeMatrixMulAddKHR %[[#MatTy1]] %[[#]] %[[#]] %[[#]] MatrixCSignedComponentsKHR|MatrixResultSignedComponentsKHR ; CHECK: OpCooperativeMatrixStoreKHR define spir_kernel void @matr_mult(ptr addrspace(1) align 1 %_arg_accA, ptr addrspace(1) align 1 %_arg_accB, ptr addrspace(1) align 4 %_arg_accC, i64 %_arg_N, i64 %_arg_K) {