@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
176176 ImmTyWaitVAVDst,
177177 ImmTyWaitVMVSrc,
178178 ImmTyBitOp3,
179+ ImmTyMatrixAFMT,
180+ ImmTyMatrixBFMT,
179181 ImmTyMatrixAReuse,
180182 ImmTyMatrixBReuse,
181183 ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
423425 bool isIndexKey8bit () const { return isImmTy (ImmTyIndexKey8bit); }
424426 bool isIndexKey16bit () const { return isImmTy (ImmTyIndexKey16bit); }
425427 bool isIndexKey32bit () const { return isImmTy (ImmTyIndexKey32bit); }
428+ bool isMatrixAFMT () const { return isImmTy (ImmTyMatrixAFMT); }
429+ bool isMatrixBFMT () const { return isImmTy (ImmTyMatrixBFMT); }
426430 bool isMatrixAReuse () const { return isImmTy (ImmTyMatrixAReuse); }
427431 bool isMatrixBReuse () const { return isImmTy (ImmTyMatrixBReuse); }
428432 bool isTFE () const { return isImmTy (ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11741178 case ImmTyWaitVAVDst: OS << " WaitVAVDst" ; break ;
11751179 case ImmTyWaitVMVSrc: OS << " WaitVMVSrc" ; break ;
11761180 case ImmTyBitOp3: OS << " BitOp3" ; break ;
1181+ case ImmTyMatrixAFMT: OS << " ImmTyMatrixAFMT" ; break ;
1182+ case ImmTyMatrixBFMT: OS << " ImmTyMatrixBFMT" ; break ;
11771183 case ImmTyMatrixAReuse: OS << " ImmTyMatrixAReuse" ; break ;
11781184 case ImmTyMatrixBReuse: OS << " ImmTyMatrixBReuse" ; break ;
11791185 case ImmTyByteSel: OS << " ByteSel" ; break ;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17141720 ParseStatus parseIndexKey8bit (OperandVector &Operands);
17151721 ParseStatus parseIndexKey16bit (OperandVector &Operands);
17161722 ParseStatus parseIndexKey32bit (OperandVector &Operands);
1723+ ParseStatus tryParseMatrixFMT (OperandVector &Operands, StringRef Name,
1724+ AMDGPUOperand::ImmTy Type);
1725+ ParseStatus parseMatrixAFMT (OperandVector &Operands);
1726+ ParseStatus parseMatrixBFMT (OperandVector &Operands);
17171727
17181728 ParseStatus parseDfmtNfmt (int64_t &Format);
17191729 ParseStatus parseUfmt (int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
18491859 const unsigned CPol);
18501860 bool validateTFE (const MCInst &Inst, const OperandVector &Operands);
18511861 std::optional<StringRef> validateLdsDirect (const MCInst &Inst);
1862+ bool validateWMMA (const MCInst &Inst, const OperandVector &Operands);
18521863 unsigned getConstantBusLimit (unsigned Opcode) const ;
18531864 bool usesConstantBus (const MCInst &Inst, unsigned OpIdx);
18541865 bool isInlineConstant (const MCInst &Inst, unsigned OpIdx) const ;
@@ -5409,6 +5420,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
54095420 return true ;
54105421}
54115422
5423+ bool AMDGPUAsmParser::validateWMMA (const MCInst &Inst,
5424+ const OperandVector &Operands) {
5425+ unsigned Opc = Inst.getOpcode ();
5426+ const MCRegisterInfo *TRI = getContext ().getRegisterInfo ();
5427+ const MCInstrDesc &Desc = MII.get (Opc);
5428+
5429+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
5430+ int FmtIdx = AMDGPU::getNamedOperandIdx (Opc, FmtOp);
5431+ if (FmtIdx == -1 )
5432+ return true ;
5433+ unsigned Fmt = Inst.getOperand (FmtIdx).getImm ();
5434+ int SrcIdx = AMDGPU::getNamedOperandIdx (Opc, SrcOp);
5435+ unsigned RegSize =
5436+ TRI->getRegClass (Desc.operands ()[SrcIdx].RegClass ).getSizeInBits ();
5437+
5438+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs (Fmt) * 32 )
5439+ return true ;
5440+
5441+ static const char *FmtNames[] = {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
5442+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
5443+ " MATRIX_FMT_FP4" };
5444+
5445+ Error (getRegLoc (mc2PseudoReg (Inst.getOperand (SrcIdx).getReg ()), Operands),
5446+ " wrong register tuple size for " + Twine (FmtNames[Fmt]));
5447+ return false ;
5448+ };
5449+
5450+ return validateFmt (AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
5451+ validateFmt (AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
5452+ }
5453+
54125454bool AMDGPUAsmParser::validateInstruction (const MCInst &Inst,
54135455 const SMLoc &IDLoc,
54145456 const OperandVector &Operands) {
@@ -5542,6 +5584,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
55425584 if (!validateTFE (Inst, Operands)) {
55435585 return false ;
55445586 }
5587+ if (!validateWMMA (Inst, Operands)) {
5588+ return false ;
5589+ }
55455590
55465591 return true ;
55475592}
@@ -7215,6 +7260,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
72157260 return tryParseIndexKey (Operands, AMDGPUOperand::ImmTyIndexKey32bit);
72167261}
72177262
7263+ ParseStatus AMDGPUAsmParser::tryParseMatrixFMT (OperandVector &Operands,
7264+ StringRef Name,
7265+ AMDGPUOperand::ImmTy Type) {
7266+ return parseStringOrIntWithPrefix (Operands, Name,
7267+ {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
7268+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
7269+ " MATRIX_FMT_FP4" },
7270+ Type);
7271+ }
7272+
7273+ ParseStatus AMDGPUAsmParser::parseMatrixAFMT (OperandVector &Operands) {
7274+ return tryParseMatrixFMT (Operands, " matrix_a_fmt" ,
7275+ AMDGPUOperand::ImmTyMatrixAFMT);
7276+ }
7277+
7278+ ParseStatus AMDGPUAsmParser::parseMatrixBFMT (OperandVector &Operands) {
7279+ return tryParseMatrixFMT (Operands, " matrix_b_fmt" ,
7280+ AMDGPUOperand::ImmTyMatrixBFMT);
7281+ }
7282+
72187283// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
72197284// values to live in a joint format operand in the MCInst encoding.
72207285ParseStatus AMDGPUAsmParser::parseDfmtNfmt (int64_t &Format) {
@@ -9316,6 +9381,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
93169381 DefaultVal);
93179382 }
93189383
9384+ int MatrixAFMTIdx =
9385+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_a_fmt);
9386+ if (MatrixAFMTIdx != -1 ) {
9387+ addOptionalImmOperand (Inst, Operands, OptIdx,
9388+ AMDGPUOperand::ImmTyMatrixAFMT, 0 );
9389+ }
9390+
9391+ int MatrixBFMTIdx =
9392+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_b_fmt);
9393+ if (MatrixBFMTIdx != -1 ) {
9394+ addOptionalImmOperand (Inst, Operands, OptIdx,
9395+ AMDGPUOperand::ImmTyMatrixBFMT, 0 );
9396+ }
9397+
93199398 if (AMDGPU::hasNamedOperand (Opc, AMDGPU::OpName::matrix_a_reuse))
93209399 addOptionalImmOperand (Inst, Operands, OptIdx,
93219400 AMDGPUOperand::ImmTyMatrixAReuse, 0 );
0 commit comments