Skip to content

Commit

Permalink
Re-land jax-ml#23261 with appropriate compatibility checks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676092618
  • Loading branch information
dfm authored and Google-ML-Automation committed Sep 18, 2024
1 parent b164d67 commit dbc03cf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 11 additions & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2012,9 +2012,19 @@ def _cos_lowering(ctx, x):
def _tan_impl(x):
return div(sin(x), cos(x))

def _tan_lowering(ctx, x):
# TODO(b/368011034): Remove after jaxlib 0.4.34 release. In 0.4.33, this
# lowering is supported, but export doesn't target a sufficiently up-to-date
# StableHLO version, and the compatibility updates from
# https://github.com/openxla/xla/pull/16649 aren't included in the 0.4.33
# release.
if ctx.is_forward_compat():
return _nary_lower_hlo(chlo.tan, ctx, x)
return _nary_lower_hlo(hlo.tan, ctx, x)

tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
mlir.register_lowering(tan_p, _tan_lowering)

def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/math.filecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3)
print_ir(jnp.bfloat16(0))(lax.sqrt)

# CHECK-LABEL: TEST: tan float16[]
# CHECK: chlo.tan
# CHECK: hlo.tan
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.tan)

Expand Down

0 comments on commit dbc03cf

Please sign in to comment.