Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions docs/programming-guide/chapter-3/fpsan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ What FpSan Preserves
FpSan preserves exact identities in the payload algebra selected by each
rewrite. The most important ones are:

- ring identities for add, subtract, multiply, FMA, and dot-like accumulation
- ring identities for add, subtract, unary negation, multiply, FMA, and
dot-like accumulation
- selected exponential identities for ``exp`` and ``exp2`` (see below for details)
- trigonometric identities for ``sin`` and ``cos``
- payload equality through casts, loads, stores, and copies
Expand Down Expand Up @@ -133,24 +134,28 @@ family?"
Common Arithmetic Ops
----------------------

Add, Sub, Mul
=============
Add, Sub, Neg, Mul
==================

Supported operations:

- ``x + y``
- ``x - y``
- ``-x``
- ``x * y``

Rewrite:

- add, subtract, or multiply the embedded payloads, then unembed the result
- add, subtract, negate, or multiply the embedded payloads, then unembed the
result

Exact preserved properties:

- ``x + 0 = x``
- ``x - 0 = x``
- ``x - x = 0``
- ``x + (-x) = 0``
- ``-(-x) = x``
- ``x * 1 = x``
- associativity and commutativity of add and mul
- distributivity of mul over add
Expand Down
27 changes: 23 additions & 4 deletions lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,24 @@ struct BinaryFloatToIntPattern : public OpRewritePattern<OpF> {
}
};

struct NegFOpPattern : public OpRewritePattern<arith::NegFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::NegFOp op,
PatternRewriter &rewriter) const override {
if (!isFloatLike(op.getType()))
return failure();

auto loc = op.getLoc();
auto inputI = embedToInt(rewriter, loc, op.getOperand());
auto zeroI = getIntConstantLike(rewriter, loc, inputI.getType(), 0);
auto resI = arith::SubIOp::create(rewriter, loc, zeroI, inputI);
auto resF = unembedToFloat(rewriter, loc, resI, op.getType());
rewriter.replaceOp(op, resF);
return success();
}
};

struct DivFOpPattern : public OpRewritePattern<arith::DivFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::DivFOp op,
Expand Down Expand Up @@ -2400,10 +2418,11 @@ class FpSanitizerPass
BinaryFloatToIntPattern<arith::MaximumFOp, arith::MaxSIOp>,
BinaryFloatToIntPattern<arith::MinNumFOp, arith::MinSIOp>,
BinaryFloatToIntPattern<arith::MaxNumFOp, arith::MaxSIOp>,
DivFOpPattern, PreciseDivFOpPattern, RemFOpPattern, FmaPattern,
ExpOpPattern, Exp2OpPattern, CosOpPattern, SinOpPattern,
ExtFOpPattern, TruncFOpPattern, FpToFpPattern, Fp4ToFpPattern,
DotPattern, DotScaledPattern>(&getContext());
NegFOpPattern, DivFOpPattern, PreciseDivFOpPattern,
RemFOpPattern, FmaPattern, ExpOpPattern, Exp2OpPattern,
CosOpPattern, SinOpPattern, ExtFOpPattern, TruncFOpPattern,
FpToFpPattern, Fp4ToFpPattern, DotPattern, DotScaledPattern>(
&getContext());
patterns.add<UnaryPattern<math::LogOp>>(&getContext(), UnaryOpId::Log);
patterns.add<UnaryPattern<math::Log2Op>>(&getContext(), UnaryOpId::Log2);
patterns.add<UnaryPattern<math::SqrtOp>>(&getContext(), UnaryOpId::Sqrt);
Expand Down
13 changes: 12 additions & 1 deletion python/test/gluon/test_fpsan.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ def _expected_sub_i32(x_i32: np.ndarray, y_i32: np.ndarray) -> np.ndarray:
return _payload_u32_to_f32_bits_i32(x_u32 - y_u32)


def _expected_neg_i32(x_i32: np.ndarray) -> np.ndarray:
x_u32 = _mix_f32_bits_to_payload_u32(x_i32).astype(np.uint64)
return _payload_u32_to_f32_bits_i32(np.uint64(0) - x_u32)


def _expected_mul_i32(x_i32: np.ndarray, y_i32: np.ndarray) -> np.ndarray:
x_u32 = _mix_f32_bits_to_payload_u32(x_i32).astype(np.uint64)
y_u32 = _mix_f32_bits_to_payload_u32(y_i32).astype(np.uint64)
Expand Down Expand Up @@ -610,7 +615,10 @@ def _unary_math_kernel(x_ptr, out_ptr, n_elements, OP: gl.constexpr, BLOCK: gl.c
offs = pid * BLOCK + gl.arange(0, BLOCK, layout=layout)
mask = offs < n_elements
x = gl.load(x_ptr + offs, mask=mask, other=0.0)
z = getattr(gl, OP)(x)
if OP == "neg":
z = -x
else:
z = getattr(gl, OP)(x)
gl.store(out_ptr + offs, z, mask=mask)


Expand Down Expand Up @@ -711,6 +719,7 @@ def _cossin_identity_kernel(x_ptr, y_ptr, lhs_ptr, rhs_ptr, n_elements, MODE: gl
[
"exp",
"exp2",
"neg",
"log",
"log2",
"cos",
Expand Down Expand Up @@ -752,6 +761,8 @@ def test_unary_math_identity(device, op, fresh_knobs):
exp_bits = _expected_exp_i32(x_bits)
elif op == "exp2":
exp_bits = _expected_exp2_i32(x_bits)
elif op == "neg":
exp_bits = _expected_neg_i32(x_bits)
elif op == "cos":
exp_bits = _expected_cos_i32(x_bits)
elif op == "sin":
Expand Down
15 changes: 15 additions & 0 deletions test/TritonGPU/fpsan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: @neg_op
tt.func public @neg_op(%a: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-DAG: %[[A:.*]] = tti.experimental_fpsan_embed %arg0 : (tensor<4xf32>) -> tensor<4xi32>
// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : tensor<4xi32>
// CHECK: %[[NEG:.*]] = arith.subi %[[ZERO]], %[[A]] : tensor<4xi32>
// CHECK: %[[OUT:.*]] = tti.experimental_fpsan_unembed %[[NEG]] : (tensor<4xi32>) -> tensor<4xf32>
// CHECK-NOT: arith.negf
%neg = arith.negf %a : tensor<4xf32>
tt.return %neg : tensor<4xf32>
}
}

// -----

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: @chained_ops
tt.func public @chained_ops(%a: tensor<4xf32>, %b: tensor<4xf32>, %c: tensor<4xf32>) -> tensor<4xf32> {
Expand Down
Loading