diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 00b5311b8415..4f491e7f9b49 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2016,22 +2016,82 @@ def square(x: ArrayLike, /) -> Array: return lax.integer_pow(x, 2) -@implements(np.deg2rad, module='numpy') @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: + r"""Convert angles from degrees to radians. + + JAX implementation of :obj:`numpy.deg2rad`. + + The angle in degrees is converted to radians by: + + .. math:: + + deg2rad(x) = x * \frac{pi}{180} + + Args: + x: scalar or array. Specifies the angle in degrees. + + Returns: + An array containing the angles in radians. + + See also: + - :func:`jax.numpy.rad2deg` and :func:`jax.numpy.degrees`: Converts the angles + from radians to degrees. + - :func:`jax.numpy.radians`: Alias of ``deg2rad``. + + Examples: + >>> x = jnp.array([60, 90, 120, 180]) + >>> jnp.deg2rad(x) + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32) + >>> x * jnp.pi / 180 + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32, weak_type=True) + """ x, = promote_args_inexact("deg2rad", x) return lax.mul(x, _lax_const(x, np.pi / 180)) -@implements(np.rad2deg, module='numpy') @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: + r"""Convert angles from radians to degrees. + + JAX implementation of :obj:`numpy.rad2deg`. + + The angle in radians is converted to degrees by: + + .. math:: + + rad2deg(x) = x * \frac{180}{pi} + + Args: + x: scalar or array. Specifies the angle in radians. + + Returns: + An array containing the angles in degrees. + + See also: + - :func:`jax.numpy.deg2rad` and :func:`jax.numpy.radians`: Converts the angles + from degrees to radians. + - :func:`jax.numpy.degrees`: Alias of ``rad2deg``. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 2*pi/3]) + >>> jnp.rad2deg(x) + Array([ 45. , 90. , 120.00001], dtype=float32) + >>> x * 180 / pi + Array([ 45., 90., 120.], dtype=float32) + """ x, = promote_args_inexact("rad2deg", x) return lax.mul(x, _lax_const(x, 180 / np.pi)) -degrees = rad2deg -radians = deg2rad +def degrees(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.rad2deg`""" + return rad2deg(x) + +def radians(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.deg2rad`""" + return deg2rad(x) @implements(np.conjugate, module='numpy') diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d3f9f2d615e2..70c9b503b895 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -6308,8 +6308,8 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin', 'around', 'bitwise_right_shift', 'divide', 'pow', - 'round_'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'degrees', 'divide', + 'pow', 'radians', 'round_'] skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp):