Skip to content

Commit

Permalink
Merge pull request #23586 from rajasekharporeddy:testbranch3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673976796
  • Loading branch information
Google-ML-Automation committed Sep 12, 2024
2 parents de9b98e + 5234173 commit 522ad79
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1774,9 +1774,40 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
for s, i in safe_zip(shape, out_indices))

@util.implements(np.resize)

@partial(jit, static_argnames=('new_shape',))
def resize(a: ArrayLike, new_shape: Shape) -> Array:
"""Return a new array with specified shape.
JAX implementation of :func:`numpy.resize`.
Args:
a: input array or scalar.
new_shape: int or tuple of ints. Specifies the shape of the resized array.
Returns:
A resized array with specified shape. The elements of ``a`` are repeated in
the resized array, if the resized array is larger than the original aray.
See also:
- :func:`jax.numpy.reshape`: Returns a reshaped copy of an array.
- :func:`jax.numpy.repeat`: Constructs an array from repeated elements.
Examples:
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> jnp.resize(x, (3, 3))
Array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=int32)
>>> jnp.resize(x, (3, 4))
Array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 1, 2, 3]], dtype=int32)
>>> jnp.resize(4, (3, 2))
Array([[4, 4],
[4, 4],
[4, 4]], dtype=int32, weak_type=True)
"""
util.check_arraylike("resize", a)
new_shape = _ensure_index_tuple(new_shape)

Expand Down

0 comments on commit 522ad79

Please sign in to comment.