Skip to content

Commit

Permalink
Improve docs for jax.numpy: conjugate, conj, imag and real
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Sep 21, 2024
1 parent 6a5553d commit c55cd37
Showing 1 changed file with 78 additions and 4 deletions.
82 changes: 78 additions & 4 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,24 +2163,98 @@ def radians(x: ArrayLike, /) -> Array:
return deg2rad(x)


@implements(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike, /) -> Array:
"""Return element-wise complex-conjugate of the input.
JAX implementation of :obj:`numpy.conjugate`.
Args:
x: inpuat array or scalar.
Returns:
An array containing the complex-conjugate of ``x``.
See also:
- :func:`jax.numpy.real`: Returns the element-wise real part of the complex
argument.
- :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the
complex argument.
Examples:
>>> jnp.conjugate(3)
Array(3, dtype=int32, weak_type=True)
>>> x = jnp.array([2-1j, 3+5j, 7])
>>> jnp.conjugate(x)
Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64)
"""
check_arraylike("conjugate", x)
return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x)
conj = conjugate


@implements(np.imag)
def conj(x: ArrayLike, /) -> Array:
"""Alias of :func:`jax.numpy.conjugate`"""
return conjugate(x)


@partial(jit, inline=True)
def imag(val: ArrayLike, /) -> Array:
"""Return element-wise imaginary of part of the complex argument.
JAX implementation of :obj:`numpy.imag`.
Args:
val: input array or scalar.
Returns:
An array containing the imaginary part of the elements of ``val``.
See also:
- :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise
complex-conjugate of the input.
- :func:`jax.numpy.real`: Returns the element-wise real part of the complex
argument.
Examples:
>>> jnp.imag(4)
Array(0, dtype=int32, weak_type=True)
>>> jnp.imag(5j)
Array(5., dtype=float32, weak_type=True)
>>> x = jnp.array([2+3j, 5-1j, -3])
>>> jnp.imag(x)
Array([ 3., -1., 0.], dtype=float32)
"""
check_arraylike("imag", val)
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)


@implements(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike, /) -> Array:
"""Return element-wise real part of the complex argument.
JAX implementation of :obj:`numpy.real`.
Args:
val: input array or scalar.
Returns:
An array containing the real part of the elements of ``val``.
See also:
- :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise
complex-conjugate of the input.
- :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the
complex argument.
Examples:
>>> jnp.real(5)
Array(5, dtype=int32, weak_type=True)
>>> jnp.real(2j)
Array(0., dtype=float32, weak_type=True)
>>> x = jnp.array([3-2j, 4+7j, -2j])
>>> jnp.real(x)
Array([ 3., 4., -0.], dtype=float32)
"""
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)

Expand Down

0 comments on commit c55cd37

Please sign in to comment.