Skip to content

Commit

Permalink
Merge pull request jax-ml#23563 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675973225
  • Loading branch information
Google-ML-Automation committed Sep 18, 2024
2 parents 4e6f690 + 2714469 commit 48d8fce
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.34

* Deprecations
* In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments
with `ndim != 1` are now deprecated, and in the future will result in an error.

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
in 0.4.30 JAX release.
Expand Down
1 change: 1 addition & 0 deletions jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,4 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-numpy-quantile-interpolation')
register('jax-numpy-trimzeros-not-1d-array')
28 changes: 20 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7018,7 +7018,7 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array:
return res


def trim_zeros(filt, trim='fb'):
def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array:
"""Trim leading and/or trailing zeros of the input array.
JAX implementation of :func:`numpy.trim_zeros`.
Expand All @@ -7040,14 +7040,26 @@ def trim_zeros(filt, trim='fb'):
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
"""
filt = core.concrete_or_error(asarray, filt,
"Error arose in the `filt` argument of trim_zeros()")
nz = (filt == 0)
# Non-array inputs are deprecated 2024-09-11
util.check_arraylike("trim_zeros", filt, emit_warning=True)
core.concrete_or_error(None, filt,
"Error arose in the `filt` argument of trim_zeros()")
filt_arr = jax.numpy.asarray(filt)
del filt
if filt_arr.ndim != 1:
# Added on 2024-09-11
if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"):
raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.")
warnings.warn(
"Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it "
"works with Arrays having ndim != 1. In the future this will result in an error.",
DeprecationWarning, stacklevel=2)
nz = (filt_arr == 0)
if reductions.all(nz):
return empty(0, _dtype(filt))
start = argmin(nz) if 'f' in trim.lower() else 0
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
return empty(0, filt_arr.dtype)
start: Array | int = argmin(nz) if 'f' in trim.lower() else 0
end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt_arr[start:len(filt_arr) - end]


def trim_zeros_tol(filt, tol, trim='fb'):
Expand Down
6 changes: 6 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,12 @@ def testTrimZeros(self, a_shape, dtype, trim):
jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)

def testTrimZerosNotOneDArray(self):
# TODO: make this an error after the deprecation period.
with self.assertWarnsRegex(DeprecationWarning,
r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):
jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]]))

@jtu.sample_product(
rank=(1, 2),
dtype=default_dtypes,
Expand Down

0 comments on commit 48d8fce

Please sign in to comment.