Skip to content

Commit

Permalink
jnp.sort_complex: fix output for N-dimensional inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 18, 2024
1 parent e15ec1e commit 2834c13
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
13 changes: 11 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8844,7 +8844,8 @@ def sort_complex(a: ArrayLike) -> Array:
a: input array. If dtype is not complex, the array will be upcast to complex.
Returns:
A sorted array of the same shape and complex dtype as the input.
A sorted array of the same shape and complex dtype as the input. If ``a``
is multi-dimensional, it is sorted along the last axis.
See also:
- :func:`jax.numpy.sort`: Return a sorted copy of an array.
Expand All @@ -8853,9 +8854,17 @@ def sort_complex(a: ArrayLike) -> Array:
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
>>> jnp.sort_complex(a)
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
Multi-dimensional arrays are sorted along the last axis:
>>> a = jnp.array([[5, 3, 4],
... [6, 9, 2]])
>>> jnp.sort_complex(a)
Array([[3.+0.j, 4.+0.j, 5.+0.j],
[2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
"""
util.check_arraylike("sort_complex", a)
a = lax.sort(asarray(a), dimension=0)
a = lax.sort(asarray(a))
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))

@util.implements(np.lexsort)
Expand Down
10 changes: 2 additions & 8 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4295,14 +4295,8 @@ def testSortStableDescending(self):
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)

@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in one_dim_array_shapes
for axis in [None]
],
dtype=all_dtypes,
)
def testSortComplex(self, dtype, shape, axis):
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testSortComplex(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker,
Expand Down

0 comments on commit 2834c13

Please sign in to comment.