Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 97 additions & 5 deletions lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,70 @@ Value fpsanExp(PatternRewriter &rewriter, Location loc, Value input) {
return fpsanExp2FromI32(rewriter, loc, scaledI, input.getType());
}

// Pair of values returned by fpsanCosSinPayload: the "cos" and "sin" results
// that are computed together from a single input.
struct FpSanCosSin {
// Cosine component; fpsanCos passes it to unembedToFloat.
Value cos;
// Sine component; fpsanSin passes it to unembedToFloat.
Value sin;
};

// Emits integer arithmetic that derives a (cos, sin) pair from the integer
// value `xI`.
//
// The bits of `xI` are consumed from the most significant to the least
// significant. Starting from (c, s) = (1, 0), every bit first applies the
// double-angle update
//   c' = c*c - s*s,   s' = 2*c*s
// and, when the bit is set, additionally applies the angle-addition update
// with the fixed pair (a, b):
//   c'' = a*c' - b*s',   s'' = a*s' + b*c'
// The loop is unrolled at rewrite time, so one group of ops is emitted per bit
// of the integer type.
//
// NOTE(review): invOddU64(5) is presumably the multiplicative inverse of 5
// modulo 2^64, which would make (a, b) = (-3/5, 4/5) in modular arithmetic and
// give a*a + b*b == 1 -- confirm against the helper's definition.
FpSanCosSin fpsanCosSinPayload(PatternRewriter &rewriter, Location loc,
                               Value xI) {
  Type payloadTy = xI.getType();
  unsigned width = getIntBitwidth(payloadTy);
  uint64_t lowMask = getLowBitsMask(width);

  // Fixed step pair, reduced to the low `width` bits.
  uint64_t inverseOfFive = invOddU64(5) & lowMask;
  uint64_t stepCosValue =
      (uint64_t{0} - ((uint64_t{3} * inverseOfFive) & lowMask)) & lowMask;
  uint64_t stepSinValue = (uint64_t{4} * inverseOfFive) & lowMask;

  auto constZero = getUIntConstantLike(rewriter, loc, payloadTy, 0);
  auto constOne = getUIntConstantLike(rewriter, loc, payloadTy, 1);
  auto constTwo = getUIntConstantLike(rewriter, loc, payloadTy, 2);
  auto stepCos = getUIntConstantLike(rewriter, loc, payloadTy, stepCosValue);
  auto stepSin = getUIntConstantLike(rewriter, loc, payloadTy, stepSinValue);

  // Running pair, initialised to (1, 0).
  Value cosAcc = constOne;
  Value sinAcc = constZero;

  // Walk the bit positions width-1, width-2, ..., 0.
  for (unsigned remaining = width; remaining > 0; --remaining) {
    unsigned bit = remaining - 1;

    // Double-angle update of the running pair.
    Value cosSq = arith::MulIOp::create(rewriter, loc, cosAcc, cosAcc);
    Value sinSq = arith::MulIOp::create(rewriter, loc, sinAcc, sinAcc);
    Value cosDoubled = arith::SubIOp::create(rewriter, loc, cosSq, sinSq);
    Value cosTimesSin = arith::MulIOp::create(rewriter, loc, cosAcc, sinAcc);
    Value sinDoubled =
        arith::MulIOp::create(rewriter, loc, constTwo, cosTimesSin);

    // Angle-addition update of the doubled pair with the fixed step (a, b).
    Value stepCosTimesCos =
        arith::MulIOp::create(rewriter, loc, stepCos, cosDoubled);
    Value stepSinTimesSin =
        arith::MulIOp::create(rewriter, loc, stepSin, sinDoubled);
    Value cosStepped = arith::SubIOp::create(rewriter, loc, stepCosTimesCos,
                                             stepSinTimesSin);
    Value stepCosTimesSin =
        arith::MulIOp::create(rewriter, loc, stepCos, sinDoubled);
    Value stepSinTimesCos =
        arith::MulIOp::create(rewriter, loc, stepSin, cosDoubled);
    Value sinStepped = arith::AddIOp::create(rewriter, loc, stepCosTimesSin,
                                             stepSinTimesCos);

    // Keep the doubled pair when the current bit of xI is clear, otherwise
    // take the stepped pair.
    auto singleBit =
        getUIntConstantLike(rewriter, loc, payloadTy, uint64_t{1} << bit);
    Value selectedBit = arith::AndIOp::create(rewriter, loc, xI, singleBit);
    Value bitIsClear = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, selectedBit, constZero);
    cosAcc = arith::SelectOp::create(rewriter, loc, bitIsClear, cosDoubled,
                                     cosStepped);
    sinAcc = arith::SelectOp::create(rewriter, loc, bitIsClear, sinDoubled,
                                     sinStepped);
  }

  return {cosAcc, sinAcc};
}

// Rewrites a cosine of `input`: embeds the input as an integer, runs
// fpsanCosSinPayload on it and converts the cosine component back to the
// input's type. Returns a null Value when the element type of `input` is not
// a FloatType, so the caller can choose a different lowering.
Value fpsanCos(PatternRewriter &rewriter, Location loc, Value input) {
  Type inputTy = input.getType();
  if (!isa<FloatType>(getElementType(inputTy)))
    return Value();

  auto embedded = embedToInt(rewriter, loc, input);
  FpSanCosSin pair = fpsanCosSinPayload(rewriter, loc, embedded);
  return unembedToFloat(rewriter, loc, pair.cos, inputTy);
}

// Rewrites a sine of `input`: embeds the input as an integer, runs
// fpsanCosSinPayload on it and converts the sine component back to the
// input's type. Returns a null Value when the element type of `input` is not
// a FloatType, so the caller can choose a different lowering.
Value fpsanSin(PatternRewriter &rewriter, Location loc, Value input) {
  Type inputTy = input.getType();
  if (!isa<FloatType>(getElementType(inputTy)))
    return Value();

  auto embedded = embedToInt(rewriter, loc, input);
  FpSanCosSin pair = fpsanCosSinPayload(rewriter, loc, embedded);
  return unembedToFloat(rewriter, loc, pair.sin, inputTy);
}

// True when the element type of `ty` (as given by getElementType) is an
// IntegerType.
bool isIntLike(Type ty) {
  auto elementTy = getElementType(ty);
  return isa<IntegerType>(elementTy);
}

bool isNumericLike(Type ty) {
Expand Down Expand Up @@ -1316,6 +1380,36 @@ struct Exp2OpPattern : public OpRewritePattern<math::Exp2Op> {
}
};

// Rewrite pattern for math::CosOp. Float-like results are replaced with the
// value built by fpsanCos; if that helper returns a null Value, the op is
// replaced with the fpsanUnaryTagged form using UnaryOpId::Cos instead.
// Ops whose result is not float-like are left untouched (match failure).
struct CosOpPattern : public OpRewritePattern<math::CosOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::CosOp op,
                                PatternRewriter &rewriter) const override {
    if (!isFloatLike(op.getType()))
      return failure();

    Location loc = op.getLoc();
    Value operand = op.getOperand();

    Value replacement = fpsanCos(rewriter, loc, operand);
    if (!replacement)
      replacement = fpsanUnaryTagged(rewriter, loc, operand, UnaryOpId::Cos);

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

// Rewrite pattern for math::SinOp. Float-like results are replaced with the
// value built by fpsanSin; if that helper returns a null Value, the op is
// replaced with the fpsanUnaryTagged form using UnaryOpId::Sin instead.
// Ops whose result is not float-like are left untouched (match failure).
struct SinOpPattern : public OpRewritePattern<math::SinOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::SinOp op,
                                PatternRewriter &rewriter) const override {
    if (!isFloatLike(op.getType()))
      return failure();

    Location loc = op.getLoc();
    Value operand = op.getOperand();

    Value replacement = fpsanSin(rewriter, loc, operand);
    if (!replacement)
      replacement = fpsanUnaryTagged(rewriter, loc, operand, UnaryOpId::Sin);

    rewriter.replaceOp(op, replacement);
    return success();
  }
};

struct ExtFOpPattern : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
Expand Down Expand Up @@ -2092,13 +2186,11 @@ class FpSanitizerPass
BinaryFloatToIntPattern<arith::SubFOp, arith::SubIOp>,
BinaryFloatToIntPattern<arith::MulFOp, arith::MulIOp>,
DivFOpPattern, PreciseDivFOpPattern, RemFOpPattern, FmaPattern,
ExpOpPattern, Exp2OpPattern, ExtFOpPattern, TruncFOpPattern,
FpToFpPattern, Fp4ToFpPattern, DotPattern, DotScaledPattern>(
&getContext());
ExpOpPattern, Exp2OpPattern, CosOpPattern, SinOpPattern,
ExtFOpPattern, TruncFOpPattern, FpToFpPattern, Fp4ToFpPattern,
DotPattern, DotScaledPattern>(&getContext());
patterns.add<UnaryPattern<math::LogOp>>(&getContext(), UnaryOpId::Log);
patterns.add<UnaryPattern<math::Log2Op>>(&getContext(), UnaryOpId::Log2);
patterns.add<UnaryPattern<math::CosOp>>(&getContext(), UnaryOpId::Cos);
patterns.add<UnaryPattern<math::SinOp>>(&getContext(), UnaryOpId::Sin);
patterns.add<UnaryPattern<math::SqrtOp>>(&getContext(), UnaryOpId::Sqrt);
patterns.add<UnaryPattern<math::RsqrtOp>>(&getContext(), UnaryOpId::Rsqrt);
patterns.add<UnaryPattern<math::ErfOp>>(&getContext(), UnaryOpId::Erf);
Expand Down
99 changes: 99 additions & 0 deletions python/test/gluon/test_fpsan.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,36 @@ def _expected_exp_i32(x_i32: np.ndarray) -> np.ndarray:
return _expected_exp2_i32(_unmix_payload_u32_to_f32_bits_i32(scaled))


def _expected_cossin_payload_u32(x_u32: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Reference (cos, sin) payloads for an array of 32-bit payloads.

    Consumes the 32 bits of every element from most to least significant.
    Starting from (1, 0), each bit applies the double-angle update and, when
    the bit is set, an extra angle-addition update with the fixed pair
    (step_c, step_s). All arithmetic is carried out in uint64 and reduced to
    the low 32 bits after each update.

    Returns the cosine and sine payloads as two uint32 arrays.
    """
    low32 = np.uint64(0xFFFFFFFF)
    inv5 = int(_inv_odd_u64(np.uint64(5)) & low32)
    step_c = np.uint64((-3 * inv5) & 0xFFFFFFFF)
    step_s = np.uint64((4 * inv5) & 0xFFFFFFFF)

    payload = x_u32.astype(np.uint64)
    cos_acc = np.ones_like(payload, dtype=np.uint64)
    sin_acc = np.zeros_like(payload, dtype=np.uint64)

    # uint64 products are allowed to wrap; only the low 32 bits are kept.
    with np.errstate(over="ignore"):
        for shift in range(31, -1, -1):
            cos_doubled = (cos_acc * cos_acc - sin_acc * sin_acc) & low32
            sin_doubled = (np.uint64(2) * cos_acc * sin_acc) & low32
            cos_stepped = (step_c * cos_doubled - step_s * sin_doubled) & low32
            sin_stepped = (step_c * sin_doubled + step_s * cos_doubled) & low32
            bit_is_set = (payload & np.uint64(1 << shift)) != 0
            cos_acc = np.where(bit_is_set, cos_stepped, cos_doubled) & low32
            sin_acc = np.where(bit_is_set, sin_stepped, sin_doubled) & low32

    return cos_acc.astype(np.uint32), sin_acc.astype(np.uint32)


def _expected_cos_i32(x_i32: np.ndarray) -> np.ndarray:
    """Expected cos result: mix the input to a payload, take the cosine
    component of the reference cos/sin payload pair, and unmix it again."""
    payload = _mix_f32_bits_to_payload_u32(x_i32)
    cos_payload = _expected_cossin_payload_u32(payload)[0]
    return _unmix_payload_u32_to_f32_bits_i32(cos_payload)


def _expected_sin_i32(x_i32: np.ndarray) -> np.ndarray:
    """Expected sin result: mix the input to a payload, take the sine
    component of the reference cos/sin payload pair, and unmix it again."""
    payload = _mix_f32_bits_to_payload_u32(x_i32)
    sin_payload = _expected_cossin_payload_u32(payload)[1]
    return _unmix_payload_u32_to_f32_bits_i32(sin_payload)


def _expected_unary_tag_i32(x_i32: np.ndarray, op: str) -> np.ndarray:
# Keep this mapping in sync with UnaryOpId in FpSanitizer.cpp.
out_u32 = _expected_unary_tag_payload_u32(_mix_f32_bits_to_payload_u32(x_i32), op)
Expand Down Expand Up @@ -619,6 +649,43 @@ def _exp_inverse_identity_kernel(x_ptr, out_ptr, n_elements, MODE: gl.constexpr,
gl.store(out_ptr + offs, z, mask=mask)


# Kernel that evaluates both sides of a trigonometric identity on x and y and
# stores them to lhs_ptr / rhs_ptr. MODE picks the identity:
#   "sin_add": sin(x + y)  vs  sin(x)cos(y) + cos(x)sin(y)
#   "sin_sub": sin(x - y)  vs  sin(x)cos(y) - cos(x)sin(y)
#   "cos_add": cos(x + y)  vs  cos(x)cos(y) - sin(x)sin(y)
#   "cos_sub": cos(x - y)  vs  cos(x)cos(y) + sin(x)sin(y)
#   "unit":    cos(x)^2 + sin(x)^2  vs  x * 0.0 + 1.0
# Any other MODE fails a static assertion.
@gluon.jit
def _cossin_identity_kernel(x_ptr, y_ptr, lhs_ptr, rhs_ptr, n_elements, MODE: gl.constexpr, BLOCK: gl.constexpr,
                            THREADS_PER_WARP: gl.constexpr):
    block_id = gl.program_id(0)
    layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4],
                                            order=[0])
    offsets = block_id * BLOCK + gl.arange(0, BLOCK, layout=layout)
    in_bounds = offsets < n_elements
    x = gl.load(x_ptr + offsets, mask=in_bounds, other=0.0)
    y = gl.load(y_ptr + offsets, mask=in_bounds, other=0.0)
    sin_x = gl.sin(x)
    sin_y = gl.sin(y)
    cos_x = gl.cos(x)
    cos_y = gl.cos(y)

    if MODE == "sin_add":
        lhs = gl.sin(x + y)
        rhs = sin_x * cos_y + cos_x * sin_y
    elif MODE == "sin_sub":
        lhs = gl.sin(x - y)
        rhs = sin_x * cos_y - cos_x * sin_y
    elif MODE == "cos_add":
        lhs = gl.cos(x + y)
        rhs = cos_x * cos_y - sin_x * sin_y
    elif MODE == "cos_sub":
        lhs = gl.cos(x - y)
        rhs = cos_x * cos_y + sin_x * sin_y
    elif MODE == "unit":
        lhs = cos_x * cos_x + sin_x * sin_x
        rhs = x * 0.0 + 1.0
    else:
        gl.static_assert(False, "unsupported MODE")

    gl.store(lhs_ptr + offsets, lhs, mask=in_bounds)
    gl.store(rhs_ptr + offsets, rhs, mask=in_bounds)


@pytest.mark.parametrize(
"op",
[
Expand Down Expand Up @@ -665,6 +732,10 @@ def test_unary_math_identity(device, op, fresh_knobs):
exp_bits = _expected_exp_i32(x_bits)
elif op == "exp2":
exp_bits = _expected_exp2_i32(x_bits)
elif op == "cos":
exp_bits = _expected_cos_i32(x_bits)
elif op == "sin":
exp_bits = _expected_sin_i32(x_bits)
else:
exp_bits = _expected_unary_tag_i32(x_bits, op)
_assert_payload_equal(out, exp_bits)
Expand Down Expand Up @@ -753,6 +824,34 @@ def test_exp_neg_reciprocal_identity(device, fresh_knobs):
_assert_payload_equal(out_neg, out_recip)


@pytest.mark.parametrize("mode", ["sin_add", "sin_sub", "cos_add", "cos_sub", "unit"])
def test_cossin_angle_identities(device, mode, fresh_knobs):
    """Under the "fpsan" instrumentation mode, both sides of each identity
    computed by _cossin_identity_kernel must produce equal payloads.

    Inputs are seeded random int32 tensors handed to the kernel through
    float32 TensorWrappers (presumably reinterpreting the raw bits as float32
    -- confirm against triton.TensorWrapper).
    """
    _require_cuda_backend(device)

    fresh_knobs.compilation.instrumentation_mode = "fpsan"

    n_elements = 1024
    BLOCK = 256
    shape = (n_elements, )

    rng = torch.Generator(device="cuda")
    rng.manual_seed(5)
    low, high = -(2**31), 2**31 - 1
    # x is drawn before y so the seeded sequence is consumed in a fixed order.
    x = torch.randint(low, high, shape, dtype=torch.int32, device="cuda", generator=rng)
    y = torch.randint(low, high, shape, dtype=torch.int32, device="cuda", generator=rng)
    lhs = torch.empty(shape, dtype=torch.int32, device="cuda")
    rhs = torch.empty(shape, dtype=torch.int32, device="cuda")

    xw, yw, lhsw, rhsw = [triton.TensorWrapper(t, dtype=torch.float32) for t in (x, y, lhs, rhs)]

    grid = (triton.cdiv(n_elements, BLOCK), )
    _cossin_identity_kernel[grid](xw, yw, lhsw, rhsw, n_elements, MODE=mode, BLOCK=BLOCK,
                                  THREADS_PER_WARP=THREADS_PER_WARP)

    _assert_payload_equal(lhs, rhs)


@gluon.jit
def _extern_unary_math_kernel(x_ptr, out_ptr, n_elements, OP: gl.constexpr, BLOCK: gl.constexpr,
THREADS_PER_WARP: gl.constexpr):
Expand Down
Loading