From 317ae065bff35e7d37a5abbef048ffb98103467c Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 13 May 2026 21:45:32 +0100 Subject: [PATCH 1/2] [fpsan] Support arith.negf Since we now generate arith.negf, it needs to be supported by fpsan. --- docs/programming-guide/chapter-3/fpsan.rst | 13 ++++++--- .../Transforms/FpSanitizer.cpp | 27 ++++++++++++++++--- python/test/gluon/test_fpsan.py | 13 ++++++++- test/TritonGPU/fpsan.mlir | 15 +++++++++++ 4 files changed, 59 insertions(+), 9 deletions(-) diff --git a/docs/programming-guide/chapter-3/fpsan.rst b/docs/programming-guide/chapter-3/fpsan.rst index b5b2ca4712fc..17c01bd658d0 100644 --- a/docs/programming-guide/chapter-3/fpsan.rst +++ b/docs/programming-guide/chapter-3/fpsan.rst @@ -96,7 +96,8 @@ What FpSan Preserves FpSan preserves exact identities in the payload algebra selected by each rewrite. The most important ones are: -- ring identities for add, subtract, multiply, FMA, and dot-like accumulation +- ring identities for add, subtract, unary negation, multiply, FMA, and + dot-like accumulation - selected exponential identities for ``exp`` and ``exp2`` (see below for details) - trigonometric identities for ``sin`` and ``cos`` - payload equality through casts, loads, stores, and copies @@ -133,24 +134,28 @@ family?" Common Arithmetic Ops ---------------------- -Add, Sub, Mul -============= +Add, Sub, Neg, Mul +================== Supported operations: - ``x + y`` - ``x - y`` +- ``-x`` - ``x * y`` Rewrite: -- add, subtract, or multiply the embedded payloads, then unembed the result +- add, subtract, negate, or multiply the embedded payloads, then unembed the + result Exact preserved properties: - ``x + 0 = x`` - ``x - 0 = x`` - ``x - x = 0`` +- ``x + (-x) = 0`` +- ``-(-x) = x`` - ``x * 1 = x`` - associativity and commutativity of add and mul - distributivity of mul over add diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index c4df44267fcd..29e918b6739b 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -1316,6 +1316,24 @@ struct BinaryFloatToIntPattern : public OpRewritePattern { } }; +struct NegFOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::NegFOp op, + PatternRewriter &rewriter) const override { + if (!isFloatLike(op.getType())) + return failure(); + + auto loc = op.getLoc(); + auto inputI = embedToInt(rewriter, loc, op.getOperand()); + auto zeroI = getIntConstantLike(rewriter, loc, inputI.getType(), 0); + auto resI = arith::SubIOp::create(rewriter, loc, zeroI, inputI); + auto resF = unembedToFloat(rewriter, loc, resI, op.getType()); + rewriter.replaceOp(op, resF); + return success(); + } +}; + struct DivFOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::DivFOp op, @@ -2400,10 +2418,11 @@ class FpSanitizerPass BinaryFloatToIntPattern, BinaryFloatToIntPattern, BinaryFloatToIntPattern, - DivFOpPattern, PreciseDivFOpPattern, RemFOpPattern, FmaPattern, - ExpOpPattern, Exp2OpPattern, CosOpPattern, SinOpPattern, - ExtFOpPattern, TruncFOpPattern, FpToFpPattern, Fp4ToFpPattern, - DotPattern, DotScaledPattern>(&getContext()); + NegFOpPattern, DivFOpPattern, PreciseDivFOpPattern, + RemFOpPattern, FmaPattern, ExpOpPattern, Exp2OpPattern, + CosOpPattern, SinOpPattern, ExtFOpPattern, TruncFOpPattern, + FpToFpPattern, Fp4ToFpPattern, DotPattern, DotScaledPattern>( + &getContext()); patterns.add>(&getContext(), UnaryOpId::Log); patterns.add>(&getContext(), UnaryOpId::Log2); patterns.add>(&getContext(), UnaryOpId::Sqrt); diff --git a/python/test/gluon/test_fpsan.py b/python/test/gluon/test_fpsan.py index f00c6e3e6abe..ed296c145921 100644 --- a/python/test/gluon/test_fpsan.py +++ b/python/test/gluon/test_fpsan.py @@ -146,6 +146,11 @@ def _expected_sub_i32(x_i32: np.ndarray, y_i32: np.ndarray) -> np.ndarray: return _payload_u32_to_f32_bits_i32(x_u32 - y_u32) +def _expected_neg_i32(x_i32: np.ndarray) -> np.ndarray: + x_u32 = _mix_f32_bits_to_payload_u32(x_i32).astype(np.uint64) + return _payload_u32_to_f32_bits_i32(np.uint64(0) - x_u32) + + def _expected_mul_i32(x_i32: np.ndarray, y_i32: np.ndarray) -> np.ndarray: x_u32 = _mix_f32_bits_to_payload_u32(x_i32).astype(np.uint64) y_u32 = _mix_f32_bits_to_payload_u32(y_i32).astype(np.uint64) @@ -610,7 +615,10 @@ def _unary_math_kernel(x_ptr, out_ptr, n_elements, OP: gl.constexpr, BLOCK: gl.c offs = pid * BLOCK + gl.arange(0, BLOCK, layout=layout) mask = offs < n_elements x = gl.load(x_ptr + offs, mask=mask, other=0.0) - z = getattr(gl, OP)(x) + if OP == "neg": + z = -x + else: + z = getattr(gl, OP)(x) gl.store(out_ptr + offs, z, mask=mask) @@ -711,6 +719,7 @@ def _cossin_identity_kernel(x_ptr, y_ptr, lhs_ptr, rhs_ptr, n_elements, MODE: gl [ "exp", "exp2", + "neg", "log", "log2", "cos", @@ -752,6 +761,8 @@ def test_unary_math_identity(device, op, fresh_knobs): exp_bits = _expected_exp_i32(x_bits) elif op == "exp2": exp_bits = _expected_exp2_i32(x_bits) + elif op == "neg": + exp_bits = _expected_neg_i32(x_bits) elif op == "cos": exp_bits = _expected_cos_i32(x_bits) elif op == "sin": diff --git a/test/TritonGPU/fpsan.mlir b/test/TritonGPU/fpsan.mlir index 36fc6420609e..ef3dfa6c415d 100644 --- a/test/TritonGPU/fpsan.mlir +++ b/test/TritonGPU/fpsan.mlir @@ -139,6 +139,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: @neg_op + tt.func public @neg_op(%a: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: %[[A:.*]] = tti.experimental_fpsan_embed %arg0 : (tensor<4xf32>) -> tensor<4xi32> + // CHECK: %[[ZERO:.*]] = arith.constant dense<0> : tensor<4xi32> + // CHECK: %[[NEG:.*]] = arith.subi %[[ZERO]], %[[A]] : tensor<4xi32> + // CHECK: %[[OUT:.*]] = tti.experimental_fpsan_unembed %[[NEG]] : (tensor<4xi32>) -> tensor<4xf32> + // CHECK-NOT: arith.negf + %neg = arith.negf %a : tensor<4xf32> + tt.return %neg : tensor<4xf32> + } +} + +// ----- + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @chained_ops tt.func public @chained_ops(%a: tensor<4xf32>, %b: tensor<4xf32>, %c: tensor<4xf32>) -> tensor<4xf32> { From 2fc27b287db3bff93d22b271ccd512826039d527 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 13 May 2026 21:58:58 +0100 Subject: [PATCH 2/2] check-dag --- test/TritonGPU/fpsan.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TritonGPU/fpsan.mlir b/test/TritonGPU/fpsan.mlir index ef3dfa6c415d..3a0b3ad00245 100644 --- a/test/TritonGPU/fpsan.mlir +++ b/test/TritonGPU/fpsan.mlir @@ -142,8 +142,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @neg_op tt.func public @neg_op(%a: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: %[[A:.*]] = tti.experimental_fpsan_embed %arg0 : (tensor<4xf32>) -> tensor<4xi32> - // CHECK: %[[ZERO:.*]] = arith.constant dense<0> : tensor<4xi32> + // CHECK-DAG: %[[A:.*]] = tti.experimental_fpsan_embed %arg0 : (tensor<4xf32>) -> tensor<4xi32> + // CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : tensor<4xi32> // CHECK: %[[NEG:.*]] = arith.subi %[[ZERO]], %[[A]] : tensor<4xi32> // CHECK: %[[OUT:.*]] = tti.experimental_fpsan_unembed %[[NEG]] : (tensor<4xi32>) -> tensor<4xf32> // CHECK-NOT: arith.negf