From c55cd37545361af4d41d3c3ec65c06a4e1962b95 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Sat, 21 Sep 2024 09:34:43 +0530 Subject: [PATCH] Improve docs for jax.numpy: conjugate, conj, imag and real --- jax/_src/numpy/ufuncs.py | 82 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 857ed8668d59..1eacf497169c 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -2163,24 +2163,98 @@ def radians(x: ArrayLike, /) -> Array: return deg2rad(x) -@implements(np.conjugate, module='numpy') @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: + """Return element-wise complex-conjugate of the input. + + JAX implementation of :obj:`numpy.conjugate`. + + Args: + x: inpuat array or scalar. + + Returns: + An array containing the complex-conjugate of ``x``. + + See also: + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.conjugate(3) + Array(3, dtype=int32, weak_type=True) + >>> x = jnp.array([2-1j, 3+5j, 7]) + >>> jnp.conjugate(x) + Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) + """ check_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) -conj = conjugate -@implements(np.imag) +def conj(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.conjugate`""" + return conjugate(x) + + @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: + """Return element-wise imaginary of part of the complex argument. + + JAX implementation of :obj:`numpy.imag`. + + Args: + val: input array or scalar. + + Returns: + An array containing the imaginary part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + + Examples: + >>> jnp.imag(4) + Array(0, dtype=int32, weak_type=True) + >>> jnp.imag(5j) + Array(5., dtype=float32, weak_type=True) + >>> x = jnp.array([2+3j, 5-1j, -3]) + >>> jnp.imag(x) + Array([ 3., -1., 0.], dtype=float32) + """ check_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) -@implements(np.real) @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: + """Return element-wise real part of the complex argument. + + JAX implementation of :obj:`numpy.real`. + + Args: + val: input array or scalar. + + Returns: + An array containing the real part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.real(5) + Array(5, dtype=int32, weak_type=True) + >>> jnp.real(2j) + Array(0., dtype=float32, weak_type=True) + >>> x = jnp.array([3-2j, 4+7j, -2j]) + >>> jnp.real(x) + Array([ 3., 4., -0.], dtype=float32) + """ check_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)