diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index d7898d72e35d..5d28f5d99fdc 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -677,7 +677,7 @@ class ScaledBlockedToMMA : public mlir::OpRewritePattern {
   mlir::LogicalResult
   matchAndRewrite(triton::DotScaledOp dotOp,
                   mlir::PatternRewriter &rewriter) const override {
-    if (computeCapability != 120)
+    if (computeCapability / 10 != 12)
       return failure();
 
     auto numCTAs = lookupNumCTAs(rewriter);
@@ -924,7 +924,7 @@ static bool mmav2SupportsFp8Operands(int computeCapability) {
   // although PTX instructions for mma v2 w/ fp8 operands exist for sm90 and
   // sm100, they are emulated as fp16 upcasts + fp16 HMMA in SASS. sm120 has
   // hardware support for fp8 operands w/ mmav2.
-  return computeCapability == 89 || computeCapability == 120;
+  return computeCapability == 89 || computeCapability / 10 == 12;
 }
 
 // promote operands of dot op if the existing combination is not natively