Skip to content

Commit

Permalink
Merge pull request #23788 from rajasekharporeddy:testbranch2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676891040
  • Loading branch information
Google-ML-Automation committed Sep 20, 2024
2 parents e2cdb79 + 0c87a23 commit 82b0e0e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
68 changes: 64 additions & 4 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2016,22 +2016,82 @@ def square(x: ArrayLike, /) -> Array:
return lax.integer_pow(x, 2)


@implements(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike, /) -> Array:
r"""Convert angles from degrees to radians.
JAX implementation of :obj:`numpy.deg2rad`.
The angle in degrees is converted to radians by:
.. math::
deg2rad(x) = x * \frac{pi}{180}
Args:
x: scalar or array. Specifies the angle in degrees.
Returns:
An array containing the angles in radians.
See also:
- :func:`jax.numpy.rad2deg` and :func:`jax.numpy.degrees`: Converts the angles
from radians to degrees.
- :func:`jax.numpy.radians`: Alias of ``deg2rad``.
Examples:
>>> x = jnp.array([60, 90, 120, 180])
>>> jnp.deg2rad(x)
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32)
>>> x * jnp.pi / 180
Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32, weak_type=True)
"""
x, = promote_args_inexact("deg2rad", x)
return lax.mul(x, _lax_const(x, np.pi / 180))


@implements(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike, /) -> Array:
r"""Convert angles from radians to degrees.
JAX implementation of :obj:`numpy.rad2deg`.
The angle in radians is converted to degrees by:
.. math::
rad2deg(x) = x * \frac{180}{pi}
Args:
x: scalar or array. Specifies the angle in radians.
Returns:
An array containing the angles in degrees.
See also:
- :func:`jax.numpy.deg2rad` and :func:`jax.numpy.radians`: Converts the angles
from degrees to radians.
- :func:`jax.numpy.degrees`: Alias of ``rad2deg``.
Examples:
>>> pi = jnp.pi
>>> x = jnp.array([pi/4, pi/2, 2*pi/3])
>>> jnp.rad2deg(x)
Array([ 45. , 90. , 120.00001], dtype=float32)
>>> x * 180 / pi
Array([ 45., 90., 120.], dtype=float32)
"""
x, = promote_args_inexact("rad2deg", x)
return lax.mul(x, _lax_const(x, 180 / np.pi))


degrees = rad2deg
radians = deg2rad
def degrees(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.rad2deg`"""
return rad2deg(x)

def radians(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.deg2rad`"""
return deg2rad(x)


@implements(np.conjugate, module='numpy')
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6308,8 +6308,8 @@ def test_lax_numpy_docstrings(self):

unimplemented = ['fromfile', 'fromiter']
aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2',
'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'pow',
'round_']
'amax', 'amin', 'around', 'bitwise_right_shift', 'degrees', 'divide',
'pow', 'radians', 'round_']
skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split']

for name in dir(jnp):
Expand Down

0 comments on commit 82b0e0e

Please sign in to comment.