FpSanitizer: dedicated cos/sin lowering and identity tests

What the patch does

lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
  * Adds the FpSanCosSin pair and fpsanCosSinPayload(). The helper starts
    from (c, s) = (1, 0) and visits the bits of the integer payload xI from
    bit (bitWidth - 1) down to bit 0. For every bit it emits a doubling
    step (c*c - s*s, 2*c*s), an "increment" step that combines the doubled
    pair with the constants (a, b), and two arith.select ops that keep the
    doubled pair when the bit of xI is zero and the incremented pair
    otherwise. The constants are a = (0 - 3 * rcp5) & mask and
    b = (4 * rcp5) & mask with rcp5 = invOddU64(5) & mask and
    mask = getLowBitsMask(bitWidth). The loop runs while the rewrite
    executes, so one group of arith ops is created per payload bit.
  * Adds fpsanCos() / fpsanSin(): both return a null Value when the isa
    check on the element type fails; otherwise they call embedToInt, run
    fpsanCosSinPayload and pass the .cos / .sin member to unembedToFloat.
  * Adds CosOpPattern / SinOpPattern: they fail to match when
    isFloatLike(op.getType()) is false, and fall back to fpsanUnaryTagged
    with UnaryOpId::Cos / UnaryOpId::Sin when fpsanCos / fpsanSin return a
    null Value.
  * Registers the two new patterns in the main patterns.add list and drops
    the two patterns.add registrations tagged UnaryOpId::Cos and
    UnaryOpId::Sin.

python/test/gluon/test_fpsan.py
  * Adds _expected_cossin_payload_u32, a NumPy version of the same
    recurrence over 32 bits, plus _expected_cos_i32 / _expected_sin_i32
    which wrap it with the existing mix / unmix helpers.
  * Adds _cossin_identity_kernel, which stores the two sides of the
    sin/cos angle-sum and angle-difference identities and of
    cos*cos + sin*sin == 1, selected by MODE.
  * test_unary_math_identity now uses the new reference functions for the
    "cos" and "sin" cases.
  * Adds test_cossin_angle_identities, which runs the kernel with
    instrumentation_mode = "fpsan" on 1024 random int32 bit patterns
    (generator seed 5) and asserts that lhs and rhs payloads are equal.

Review notes
  * NOTE(review): this copy of the patch had been collapsed onto a few very
    long lines. Line breaks were rebuilt from the hunk line counts; the
    indentation was rebuilt by convention and may not match the repository
    byte for byte.
  * NOTE(review): template argument lists appear to be missing from this
    copy, e.g. "static_cast(bitWidth)", "isa(getElementType(...))",
    "OpRewritePattern {" and "patterns.add>(...)". They are left exactly
    as found because most cannot be recovered from this text. Re-export
    the patch from the repository before trying to apply it.

diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
index c9068da03e12..40cc8eb1efc7 100644
--- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
+++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
@@ -705,6 +705,70 @@ Value fpsanExp(PatternRewriter &rewriter, Location loc, Value input) {
   return fpsanExp2FromI32(rewriter, loc, scaledI, input.getType());
 }
 
+struct FpSanCosSin {
+  Value cos;
+  Value sin;
+};
+
+FpSanCosSin fpsanCosSinPayload(PatternRewriter &rewriter, Location loc,
+                               Value xI) {
+  Type intTy = xI.getType();
+  unsigned bitWidth = getIntBitwidth(intTy);
+  uint64_t mask = getLowBitsMask(bitWidth);
+  uint64_t rcp5 = invOddU64(5) & mask;
+  uint64_t aValue = (uint64_t{0} - ((uint64_t{3} * rcp5) & mask)) & mask;
+  uint64_t bValue = (uint64_t{4} * rcp5) & mask;
+
+  auto zero = getUIntConstantLike(rewriter, loc, intTy, 0);
+  auto one = getUIntConstantLike(rewriter, loc, intTy, 1);
+  auto two = getUIntConstantLike(rewriter, loc, intTy, 2);
+  auto a = getUIntConstantLike(rewriter, loc, intTy, aValue);
+  auto b = getUIntConstantLike(rewriter, loc, intTy, bValue);
+
+  Value c = one;
+  Value s = zero;
+  for (int bit = static_cast(bitWidth) - 1; bit >= 0; --bit) {
+    Value cc = arith::MulIOp::create(rewriter, loc, c, c);
+    Value ss = arith::MulIOp::create(rewriter, loc, s, s);
+    Value cDouble = arith::SubIOp::create(rewriter, loc, cc, ss);
+    Value cs = arith::MulIOp::create(rewriter, loc, c, s);
+    Value sDouble = arith::MulIOp::create(rewriter, loc, two, cs);
+
+    Value ac = arith::MulIOp::create(rewriter, loc, a, cDouble);
+    Value bs = arith::MulIOp::create(rewriter, loc, b, sDouble);
+    Value cInc = arith::SubIOp::create(rewriter, loc, ac, bs);
+    Value as = arith::MulIOp::create(rewriter, loc, a, sDouble);
+    Value bc = arith::MulIOp::create(rewriter, loc, b, cDouble);
+    Value sInc = arith::AddIOp::create(rewriter, loc, as, bc);
+
+    auto bitMask =
+        getUIntConstantLike(rewriter, loc, intTy, uint64_t{1} << bit);
+    auto masked = arith::AndIOp::create(rewriter, loc, xI, bitMask);
+    auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
+                                        masked, zero);
+    c = arith::SelectOp::create(rewriter, loc, isZero, cDouble, cInc);
+    s = arith::SelectOp::create(rewriter, loc, isZero, sDouble, sInc);
+  }
+
+  return {c, s};
+}
+
+Value fpsanCos(PatternRewriter &rewriter, Location loc, Value input) {
+  if (!isa(getElementType(input.getType())))
+    return Value();
+  auto cosSin =
+      fpsanCosSinPayload(rewriter, loc, embedToInt(rewriter, loc, input));
+  return unembedToFloat(rewriter, loc, cosSin.cos, input.getType());
+}
+
+Value fpsanSin(PatternRewriter &rewriter, Location loc, Value input) {
+  if (!isa(getElementType(input.getType())))
+    return Value();
+  auto cosSin =
+      fpsanCosSinPayload(rewriter, loc, embedToInt(rewriter, loc, input));
+  return unembedToFloat(rewriter, loc, cosSin.sin, input.getType());
+}
+
 bool isIntLike(Type ty) { return isa(getElementType(ty)); }
 
 bool isNumericLike(Type ty) {
@@ -1316,6 +1380,36 @@ struct Exp2OpPattern : public OpRewritePattern {
   }
 };
 
+struct CosOpPattern : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(math::CosOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isFloatLike(op.getType()))
+      return failure();
+    Value result = fpsanCos(rewriter, op.getLoc(), op.getOperand());
+    if (!result)
+      result = fpsanUnaryTagged(rewriter, op.getLoc(), op.getOperand(),
+                                UnaryOpId::Cos);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+struct SinOpPattern : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(math::SinOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isFloatLike(op.getType()))
+      return failure();
+    Value result = fpsanSin(rewriter, op.getLoc(), op.getOperand());
+    if (!result)
+      result = fpsanUnaryTagged(rewriter, op.getLoc(), op.getOperand(),
+                                UnaryOpId::Sin);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ExtFOpPattern : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::ExtFOp op,
@@ -2092,13 +2186,11 @@ class FpSanitizerPass
                  BinaryFloatToIntPattern,
                  BinaryFloatToIntPattern,
                  DivFOpPattern, PreciseDivFOpPattern, RemFOpPattern, FmaPattern,
-                 ExpOpPattern, Exp2OpPattern, ExtFOpPattern, TruncFOpPattern,
-                 FpToFpPattern, Fp4ToFpPattern, DotPattern, DotScaledPattern>(
-        &getContext());
+                 ExpOpPattern, Exp2OpPattern, CosOpPattern, SinOpPattern,
+                 ExtFOpPattern, TruncFOpPattern, FpToFpPattern, Fp4ToFpPattern,
+                 DotPattern, DotScaledPattern>(&getContext());
     patterns.add>(&getContext(), UnaryOpId::Log);
     patterns.add>(&getContext(), UnaryOpId::Log2);
-    patterns.add>(&getContext(), UnaryOpId::Cos);
-    patterns.add>(&getContext(), UnaryOpId::Sin);
     patterns.add>(&getContext(), UnaryOpId::Sqrt);
     patterns.add>(&getContext(), UnaryOpId::Rsqrt);
     patterns.add>(&getContext(), UnaryOpId::Erf);
diff --git a/python/test/gluon/test_fpsan.py b/python/test/gluon/test_fpsan.py
index 3b73ab30a99d..29da2269fceb 100644
--- a/python/test/gluon/test_fpsan.py
+++ b/python/test/gluon/test_fpsan.py
@@ -237,6 +237,36 @@ def _expected_exp_i32(x_i32: np.ndarray) -> np.ndarray:
     return _expected_exp2_i32(_unmix_payload_u32_to_f32_bits_i32(scaled))
 
 
+def _expected_cossin_payload_u32(x_u32: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    mask = np.uint64(0xFFFFFFFF)
+    rcp5 = int(_inv_odd_u64(np.uint64(5)) & mask)
+    a = np.uint64((-3 * rcp5) & 0xFFFFFFFF)
+    b = np.uint64((4 * rcp5) & 0xFFFFFFFF)
+    x = x_u32.astype(np.uint64)
+    c = np.ones_like(x, dtype=np.uint64)
+    s = np.zeros_like(x, dtype=np.uint64)
+    with np.errstate(over="ignore"):
+        for i in range(32):
+            c_double = (c * c - s * s) & mask
+            s_double = (np.uint64(2) * c * s) & mask
+            c_inc = (a * c_double - b * s_double) & mask
+            s_inc = (a * s_double + b * c_double) & mask
+            inc = (x & np.uint64(1 << (31 - i))) != 0
+            c = np.where(inc, c_inc, c_double) & mask
+            s = np.where(inc, s_inc, s_double) & mask
+    return c.astype(np.uint32), s.astype(np.uint32)
+
+
+def _expected_cos_i32(x_i32: np.ndarray) -> np.ndarray:
+    c, _ = _expected_cossin_payload_u32(_mix_f32_bits_to_payload_u32(x_i32))
+    return _unmix_payload_u32_to_f32_bits_i32(c)
+
+
+def _expected_sin_i32(x_i32: np.ndarray) -> np.ndarray:
+    _, s = _expected_cossin_payload_u32(_mix_f32_bits_to_payload_u32(x_i32))
+    return _unmix_payload_u32_to_f32_bits_i32(s)
+
+
 def _expected_unary_tag_i32(x_i32: np.ndarray, op: str) -> np.ndarray:
     # Keep this mapping in sync with UnaryOpId in FpSanitizer.cpp.
     out_u32 = _expected_unary_tag_payload_u32(_mix_f32_bits_to_payload_u32(x_i32), op)
@@ -619,6 +649,43 @@ def _exp_inverse_identity_kernel(x_ptr, out_ptr, n_elements, MODE: gl.constexpr,
     gl.store(out_ptr + offs, z, mask=mask)
 
 
+@gluon.jit
+def _cossin_identity_kernel(x_ptr, y_ptr, lhs_ptr, rhs_ptr, n_elements, MODE: gl.constexpr, BLOCK: gl.constexpr,
+                            THREADS_PER_WARP: gl.constexpr):
+    pid = gl.program_id(0)
+    layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4],
+                                            order=[0])
+    offs = pid * BLOCK + gl.arange(0, BLOCK, layout=layout)
+    mask = offs < n_elements
+    x = gl.load(x_ptr + offs, mask=mask, other=0.0)
+    y = gl.load(y_ptr + offs, mask=mask, other=0.0)
+    sx = gl.sin(x)
+    sy = gl.sin(y)
+    cx = gl.cos(x)
+    cy = gl.cos(y)
+
+    if MODE == "sin_add":
+        lhs = gl.sin(x + y)
+        rhs = sx * cy + cx * sy
+    elif MODE == "sin_sub":
+        lhs = gl.sin(x - y)
+        rhs = sx * cy - cx * sy
+    elif MODE == "cos_add":
+        lhs = gl.cos(x + y)
+        rhs = cx * cy - sx * sy
+    elif MODE == "cos_sub":
+        lhs = gl.cos(x - y)
+        rhs = cx * cy + sx * sy
+    elif MODE == "unit":
+        lhs = cx * cx + sx * sx
+        rhs = x * 0.0 + 1.0
+    else:
+        gl.static_assert(False, "unsupported MODE")
+
+    gl.store(lhs_ptr + offs, lhs, mask=mask)
+    gl.store(rhs_ptr + offs, rhs, mask=mask)
+
+
 @pytest.mark.parametrize(
     "op",
     [
@@ -665,6 +732,10 @@ def test_unary_math_identity(device, op, fresh_knobs):
         exp_bits = _expected_exp_i32(x_bits)
     elif op == "exp2":
         exp_bits = _expected_exp2_i32(x_bits)
+    elif op == "cos":
+        exp_bits = _expected_cos_i32(x_bits)
+    elif op == "sin":
+        exp_bits = _expected_sin_i32(x_bits)
     else:
         exp_bits = _expected_unary_tag_i32(x_bits, op)
     _assert_payload_equal(out, exp_bits)
@@ -753,6 +824,34 @@ def test_exp_neg_reciprocal_identity(device, fresh_knobs):
     _assert_payload_equal(out_neg, out_recip)
 
 
+@pytest.mark.parametrize("mode", ["sin_add", "sin_sub", "cos_add", "cos_sub", "unit"])
+def test_cossin_angle_identities(device, mode, fresh_knobs):
+    _require_cuda_backend(device)
+
+    fresh_knobs.compilation.instrumentation_mode = "fpsan"
+
+    n_elements = 1024
+    BLOCK = 256
+
+    g = torch.Generator(device="cuda")
+    g.manual_seed(5)
+    x = torch.randint(-(2**31), 2**31 - 1, (n_elements, ), dtype=torch.int32, device="cuda", generator=g)
+    y = torch.randint(-(2**31), 2**31 - 1, (n_elements, ), dtype=torch.int32, device="cuda", generator=g)
+    lhs = torch.empty((n_elements, ), dtype=torch.int32, device="cuda")
+    rhs = torch.empty((n_elements, ), dtype=torch.int32, device="cuda")
+
+    xw = triton.TensorWrapper(x, dtype=torch.float32)
+    yw = triton.TensorWrapper(y, dtype=torch.float32)
+    lhsw = triton.TensorWrapper(lhs, dtype=torch.float32)
+    rhsw = triton.TensorWrapper(rhs, dtype=torch.float32)
+
+    grid = (triton.cdiv(n_elements, BLOCK), )
+    _cossin_identity_kernel[grid](xw, yw, lhsw, rhsw, n_elements, MODE=mode, BLOCK=BLOCK,
+                                  THREADS_PER_WARP=THREADS_PER_WARP)
+
+    _assert_payload_equal(lhs, rhs)
+
+
 @gluon.jit
 def _extern_unary_math_kernel(x_ptr, out_ptr, n_elements, OP: gl.constexpr, BLOCK: gl.constexpr,
                               THREADS_PER_WARP: gl.constexpr):