From d77af7a655b17e7fbf929d34b01ed0153291b329 Mon Sep 17 00:00:00 2001 From: Selam Waktola Date: Mon, 3 Jun 2024 17:15:26 -0700 Subject: [PATCH] append_docstring_added append_docstring_modified append_doc_line_break append_doc_linting_fixed --- jax/_src/numpy/lax_numpy.py | 46 +++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 2e222bf6c612..ae67defa8918 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3955,12 +3955,54 @@ def trim_zeros_tol(filt, tol, trim='fb'): end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt[start:len(filt) - end] - -@util.implements(np.append) @partial(jit, static_argnames=('axis',)) def append( arr: ArrayLike, values: ArrayLike, axis: int | None = None ) -> Array: + """Return a new array with values appended to the end of the original array. + + JAX implementation of :func:`numpy.append`. + + Args: + arr: original array. + values: values to be appended to the array. The ``values`` must have + the same number of dimensions as ``arr``, and all dimensions must + match except in the specified axis. + axis: axis along which to append values. If None (default), both ``arr`` + and ``values`` will be flattened before appending. + + Returns: + A new array with values appended to ``arr``. + + See also: + - :func:`jax.numpy.insert` + - :func:`jax.numpy.delete` + + Examples: + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.append(a, b) + Array([1, 2, 3, 4, 5, 6], dtype=int32) + + Appending along a specific axis: + + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([[5, 6]]) + >>> jnp.append(a, b, axis=0) + Array([[1, 2], + [3, 4], + [5, 6]], dtype=int32) + + Appending along a trailing axis: + + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> b = jnp.array([[7], [8]]) + >>> jnp.append(a, b, axis=1) + Array([[1, 2, 3, 7], + [4, 5, 6, 8]], dtype=int32) + """ if axis is None: return concatenate([ravel(arr), ravel(values)], 0) else: