ENH: pad: add delegation #72

Merged
merged 25 commits into from
Jan 15, 2025

Conversation

lucascolley
Member

@lucascolley lucascolley commented Dec 26, 2024

closes gh-69

TO-DO here:

  • adjust scope section of the docs
  • rename _delegators.py to _delegation.py
  • add pad to the api reference
  • figure out where the function docstrings should live
  • test the delegation

follow-up:

  • add a section on adding delegation to the contributor docs
  • open a tracker for adding delegation to functions
  • split non-agnostic functions out from _funcs.py

@lucascolley lucascolley changed the title ENH: add pad ENH: pad: add delegation Dec 26, 2024
@lucascolley
Member Author

lucascolley commented Dec 26, 2024

@ev-br this is what I'm thinking. Of course, for only 1 function here this overcomplicates things, but I think we will see the benefits of cleanly separating which files are using the standard API and which are using library-specific functions, if/when more delegation is added. And it leaves the door open for more "smart" delegation if that would become appropriate.

I'd like to leave this as draft until gh-53 is merged, as that PR adds a way to test with existing array libraries (and this one would cause merge conflicts anyway if we did that here first).

@lucascolley lucascolley added enhancement New feature or request delegation labels Dec 26, 2024
Member

@ev-br ev-br left a comment

So the main points are:

  • you effectively whitelist the kernels as opposed to a catch-all delegating to xp.pad. An unknown backend will get a 'generic' implementation instead of being asked for a pad namespaced function. Makes sense.
  • the approach is way more light-touch than in scipy.{signal,ndimage,special}: in there you need either a delegation dict or a whole module. Here it's simple enough to keep the if-elif chain inside an affected function. Would be straightforward to refactor if a need arises.
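As a rough illustration of the allow-list idea in the first bullet, the sketch below delegates to `xp.pad` only for backends named explicitly and sends everything else to a generic fallback. The names `KNOWN_PAD_BACKENDS` and `generic_pad` are hypothetical, and the check on `xp.__name__` is a simplification of the PR's `_delegate(xp, ...)` helper, not its actual implementation.

```python
from types import ModuleType

# Hypothetical allow-list; the PR itself checks `_delegate(xp, NUMPY, JAX, CUPY)`.
KNOWN_PAD_BACKENDS = {"numpy", "jax.numpy", "cupy"}


def generic_pad(x, pad_width, constant_values=0):
    """Stand-in for the array-agnostic fallback (1-D Python lists only)."""
    before, after = pad_width
    return [constant_values] * before + list(x) + [constant_values] * after


def pad(x, pad_width, *, xp: ModuleType, constant_values=0):
    # Delegate only to backends that were explicitly allow-listed.
    if xp.__name__ in KNOWN_PAD_BACKENDS:
        return xp.pad(x, pad_width, constant_values=constant_values)
    # An unknown backend is never asked for `xp.pad`; it gets the generic path.
    return generic_pad(x, pad_width, constant_values)
```

With this shape, a backend that happens to expose an incompatible `pad` attribute cannot be picked up by accident, at the cost of having to extend the allow-list by hand.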

A really minor point is the _delegators.py name: in scipy this only contains the _signature functions, and the kernels are elsewhere. Here this module contains the public symbol(s) including the if-elif delegation chains. This is fine though, and will also be simple to refactor if the need arises (that's a big fat if here).

My other inline comments are all optional and minor.

Concluding, this refactor LGTM.

src/array_api_extra/_delegators.py Outdated Show resolved Hide resolved
src/array_api_extra/_lib/_utils/_compat.py Show resolved Hide resolved
Member Author

@lucascolley lucascolley left a comment

TO-DO here:

  • adjust scope section of the docs
  • rename _delegators.py to _delegation.py
  • add pad to the api reference
  • figure out where the function docstrings should live
  • test the delegation

follow-up:

  • add a section on adding delegation to the contributor docs
  • open a tracker for adding delegation to functions

@ev-br
Member

ev-br commented Jan 7, 2025

Tried using this on scipy/scipy#22226 and apparently there's an issue with the 'generic' implementation: downstream usage involves pad_width being a tuple:

> /home/br/repos/scipy/scipy/build-install/lib/python3.12/site-packages/scipy/signal/tests/test_upfirdn.py(147)test_singleton()
-> want = xpx.pad(x, (len_h // 2, (len_h - 1) // 2), 'constant', xp=xp)
(Pdb) p x.shape
(1,)
(Pdb) p len_h
1
(Pdb) n
TypeError: unsupported operand type(s) for +: 'int' and 'tuple'
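One way a generic implementation could avoid this kind of `int`/`tuple` mix-up is to normalize `pad_width` into one `(before, after)` pair per axis before doing any arithmetic. The helper below is a minimal pure-Python sketch under that assumption; `normalize_pad_width` is a hypothetical name and is not taken from the PR.

```python
def normalize_pad_width(pad_width, ndim):
    """Return a list of (before, after) pairs, one per axis."""
    if isinstance(pad_width, int):
        # A single integer pads every axis equally on both sides.
        return [(pad_width, pad_width)] * ndim
    pad_width = list(pad_width)
    if len(pad_width) == 2 and all(isinstance(p, int) for p in pad_width):
        # A single (before, after) pair applies to every axis.
        return [tuple(pad_width)] * ndim
    pairs = [tuple(p) for p in pad_width]
    if len(pairs) == 1:
        # A one-element sequence is broadcast to all axes.
        pairs = pairs * ndim
    if len(pairs) != ndim or any(len(p) != 2 for p in pairs):
        raise ValueError("expect a 2-tuple (before, after) per axis")
    return pairs
```

For the failing call above, `x.shape == (1,)` and `len_h == 1`, so the argument is `(0, 0)` and the helper would return `[(0, 0)]`.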

@lucascolley
Member Author

@ev-br could you take another look at the delegation layer here now that the pad_width param has changed?

@ev-br
Member

ev-br commented Jan 7, 2025

With this patch,

$ git diff
diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py
index 8ecff02..b8b6423 100644
--- a/src/array_api_extra/_delegation.py
+++ b/src/array_api_extra/_delegation.py
@@ -117,7 +117,7 @@ def pad(
         pad_width = xp.asarray(pad_width)
         pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
         pad_width = xp.flip(pad_width, axis=(0,)).flatten()
-        return xp.nn.functional.pad(x, (pad_width,), value=constant_values)
+        return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values)
 
     if _delegate(xp, NUMPY, JAX, CUPY):
         return xp.pad(x, pad_width, mode, constant_values=constant_values)
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index bada663..2a5a122 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -1,8 +1,9 @@
 import contextlib
 import warnings
 
-# data-apis/array-api-strict#6
-import array_api_strict as xp  # type: ignore[import-untyped]  # pyright: ignore[reportMissingTypeStubs]
+import torch as xp
 import numpy as np
 import pytest
 from numpy.testing import assert_allclose, assert_array_equal, assert_equal
@@ -20,6 +21,7 @@ from array_api_extra import (
 from array_api_extra._lib._utils._typing import Array
 
 
+
 class TestAtLeastND:
     def test_0D(self):
         x = xp.asarray(1)
@@ -415,6 +417,10 @@ class TestPad:
         assert pad(a, 2).device == device
 
     def test_xp(self):
         assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
 
     def test_tuple_width(self):
@@ -425,7 +431,7 @@ class TestPad:
         padded = pad(a, (1, 2))
         assert padded.shape == (6, 7)
 
-        with pytest.raises(ValueError, match="expect a 2-tuple"):
+        with pytest.raises((ValueError, RuntimeError)): #, match="expect a 2-tuple"):  NB: exc types and messages differ!
             pad(a, [(1, 2, 3)])  # type: ignore[list-item]  # pyright: ignore[reportArgumentType]
 
     def test_list_of_tuples_width(self):

I see two failures:

FAILED tests/test_funcs.py::TestPad::test_device - AttributeError: module 'torch' has no attribute 'Device'
FAILED tests/test_funcs.py::TestPad::test_xp - TypeError: flip() missing 1 required positional arguments: "dims"

The second one is interesting: since the torch code path is called with an explicit xp=xp, it gets an unwrapped torch namespace. Not sure if that's what you had in mind: use array_api_compat.torch or the unwrapped torch.
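For readers following the torch branch in the patch above: `torch.nn.functional.pad` takes a flat sequence of pad amounts ordered from the last axis to the first, whereas NumPy-style `pad_width` lists pairs from the first axis onward. The pure-Python sketch below mirrors the flip-then-flatten step on plain tuples; `to_torch_pad` is a hypothetical helper, not code from the PR.

```python
def to_torch_pad(pairs):
    """Flatten per-axis (before, after) pairs into last-axis-first order."""
    flat = []
    for before, after in reversed(pairs):
        flat.extend((before, after))
    return tuple(flat)


# Axis 0 padded by (1, 2), axis 1 padded by (3, 4):
print(to_torch_pad([(1, 2), (3, 4)]))  # (3, 4, 1, 2)
```

The one-line fix in the patch has the same intent: `(pad_width,)` wraps the whole flattened array in a 1-tuple, while `tuple(pad_width)` produces the flat sequence of individual values.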

@lucascolley
Member Author

The second one is interesting: since the torch code path is called with an explicit xp=xp, it gets an unwrapped torch namespace. Not sure if that's what you had in mind: use array_api_compat.torch or the unwrapped torch.

If xp is passed it should be compatible, this is documented at https://data-apis.org/array-api-extra/#usage. So when the tests are expanded to alternative backends, the test_xp tests should call array_namespace first.

@ev-br
Member

ev-br commented Jan 8, 2025

could you take another look at the delegation layer here now that the pad_width param has changed?

Looks like delegation is not working at all: _delegate(xp) returns False.

To test, apply the following diff to scipy/scipy#22122

diff --git a/scipy/signal/_signaltools.py b/scipy/signal/_signaltools.py
index c208278744..eb39956417 100644
--- a/scipy/signal/_signaltools.py
+++ b/scipy/signal/_signaltools.py
@@ -871,6 +871,13 @@ def _split(x, indices_or_sections, axis, xp):
 
 
 def xp_pad(x, pad_width, mode='constant', *, xp, **kwargs):
+
+    breakpoint()
+
+    import array_api_extra as xpx
+    return xpx.pad(x, pad_width, mode=mode, xp=xp, **kwargs)
+
+

then pip editable install array_api_extra from this PR, and

$ python dev.py test -t scipy/signal/tests/test_signaltools.py::TestOAConvolve -v -b torch
💻  ninja -C /home/br/repos/scipy/scipy/build -j4
... snip ...

scipy/signal/tests/test_signaltools.py::TestOAConvolve::test_1d_noaxes[torch-full-True-50-6] 
>>>>>>>>>>>>>>>>>>>>>>> PDB set_trace (IO-capturing turned off) >>>>>>>>>>>>>>>>>>>>>>>>
> /home/br/repos/scipy/scipy/build-install/lib/python3.12/site-packages/scipy/signal/_signaltools.py(877)xp_pad()
-> import array_api_extra as xpx
(Pdb) n
> /home/br/repos/scipy/scipy/build-install/lib/python3.12/site-packages/scipy/signal/_signaltools.py(878)xp_pad()
-> return xpx.pad(x, pad_width, mode=mode, xp=xp, **kwargs)
(Pdb) s
--Call--
> /home/br/repos/array-api-extra/src/array_api_extra/_delegation.py(74)pad()
-> def pad(
(Pdb) n
> /home/br/repos/array-api-extra/src/array_api_extra/_delegation.py(109)pad()
-> xp = array_namespace(x) if xp is None else xp
(Pdb) n
> /home/br/repos/array-api-extra/src/array_api_extra/_delegation.py(111)pad()
-> if mode != "constant":
(Pdb) n
> /home/br/repos/array-api-extra/src/array_api_extra/_delegation.py(116)pad()
-> if _delegate(xp, TORCH):
(Pdb) p _delegate(xp, TORCH)
False
(Pdb) p xp
<module 'scipy._lib.array_api_compat.torch' from '/home/br/repos/scipy/scipy/build-install/lib/python3.12/site-packages/scipy/_lib/array_api_compat/torch/__init__.py'>
(Pdb)

@lucascolley
Member Author

I can't reproduce immediately:

array-api-extra on  pad-delegate [$] via  v3.10.10 took 9s 
❯ pixi r ipython  
✨ Pixi task (ipython in dev): ipython
Python 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:19:53) [Clang 18.1.8 ]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.31.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: from array_api_extra._delegation import _delegate

In [2]: import torch

In [3]: from array_api_extra._delegation import TORCH

In [4]: TORCH
Out[4]: <IsNamespace.TORCH: functools.partial(<function is_torch_namespace at 0x105695a80>)>

In [5]: TORCH.value
Out[5]: functools.partial(<function is_torch_namespace at 0x105695a80>)

In [6]: TORCH.value(torch)
Out[6]: True

In [7]: _delegate(torch, TORCH)
Out[7]: True

In [8]: from array_api_compat import torch

In [9]: _delegate(torch, TORCH)
Out[9]: True

@lucascolley
Member Author

lucascolley commented Jan 8, 2025

the implementation of is_torch_namespace:

return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
...
def _compat_module_name():
    assert __name__.endswith('.common._helpers')
    return __name__.removesuffix('.common._helpers')

My inkling is that this fails for you because you're mixing a namespace from scipy._lib.array_api_compat with pip installed array-api-extra. array-api-extra requires array-api-compat to be vendored iff array-api-extra itself is.
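The mismatch can be reproduced without torch or SciPy installed, since the check only compares module names. In the sketch below, `make_is_torch_namespace` is a hypothetical factory that stands in for the name-based test quoted above, bound to one particular copy of array-api-compat; the fake modules carry only a `__name__`.

```python
from types import ModuleType


def make_is_torch_namespace(compat_module_name):
    """Build a name-based checker bound to one copy of array-api-compat."""
    accepted = {"torch", compat_module_name + ".torch"}

    def is_torch_namespace(xp):
        return xp.__name__ in accepted

    return is_torch_namespace


# Checker as it would behave in a pip-installed (non-vendored) copy.
is_torch_external = make_is_torch_namespace("array_api_compat")
# Checker as it would behave inside SciPy's vendored copy.
is_torch_vendored = make_is_torch_namespace("scipy._lib.array_api_compat")

external_xp = ModuleType("array_api_compat.torch")
vendored_xp = ModuleType("scipy._lib.array_api_compat.torch")

print(is_torch_external(external_xp))  # True
print(is_torch_external(vendored_xp))  # False: mixed vendored and external
print(is_torch_vendored(vendored_xp))  # True
```

This matches the pdb session above, where `xp` is `scipy._lib.array_api_compat.torch` but the checker comes from an externally installed package.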

@ev-br
Member

ev-br commented Jan 8, 2025

To me it looks like bad form to depend on how a package is installed.

@lucascolley
Member Author

Sounds good if you can find a way to stop depending on that, but I don't think it needs to block this.

@ev-br
Member

ev-br commented Jan 8, 2025

Well, I certainly don't want to block this. If you think it's fine and your testing of scipy shows that it works for the vendored copy, great, let's roll with it and fix issues when found. It'll be quite a bit easier when it's released, so there's no need for hacks to even test this.
If this works when everything is external or when everything is vendored, and fails only when the two are mixed, I agree that that's fine.

@lucascolley
Member Author

Yep, makes sense. We should still run some torch/jax tests over this at least in CI here before releasing.

@lucascolley
Member Author

@ev-br any chance you could take a look at the CI failures?

@lucascolley
Member Author

@crusaderky does e0046c5 (#72) look okay to you? Maybe it is too much to have 4 different fields, name, value, library_name and module_name?

@ev-br
Member

ev-br commented Jan 15, 2025

any chance you could take a look at the CI failures?

AFAICS, all CI failures come from testing for an error

        with pytest.raises(ValueError, match="expect a 2-tuple"):
>           pad(a, [(1, 2, 3)])  # type: ignore[list-item]  # pyright: ignore[reportArgumentType]

and simply signal the fact that the match regex is too strict. Different backends have differently worded messages, and there's no reason to expect that all of them have the same wording.

So a fix is simply to relax the regex or drop it altogether. You're testing for an error and an error you get. End of story, move on, I'd say.
I wouldn't be surprised if e.g. torch emits a RuntimeError, in which case I'd just do pytest.raises((ValueError, RuntimeError)).
Of course, you can add a try-except and remap the error, but why bother.
The thing looks a bit over-engineered as is.
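For comparison, the try-except remapping mentioned above could look roughly like the sketch below. `pad_with_uniform_errors` is a hypothetical wrapper, and the assumption that a backend raises `RuntimeError` for a malformed `pad_width` is taken from the speculation in this comment, not verified against any library.

```python
def pad_with_uniform_errors(backend_pad, *args, **kwargs):
    """Call a backend pad function, re-raising RuntimeError as ValueError."""
    try:
        return backend_pad(*args, **kwargs)
    except RuntimeError as exc:
        # Keep the backend's own message; only the exception type is unified.
        raise ValueError(str(exc)) from exc
```

The wrapper buys a single exception type for callers but still leaves the message backend-specific, which is why relaxing the test to accept either type is the lighter option.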

Contributor

@crusaderky crusaderky left a comment

I'm sorry to be blunt, but I have to give a solid -1 to this.

The whole PR feels very over-engineered. I think that the current approach of at and nunique, which is to have a special-case code path for selected backends, is a lot more readable and maintainable than this, without having any drawbacks short of ideological purity.

at specifically shows how you can't implement this pattern with it, as the special-case code is deep inside a private method, so you'd have to write a bunch of very bespoke private functions that pinball around different modules of array-api-extra in order to make it conform to the pattern.

src/array_api_extra/_lib/_funcs.py Show resolved Hide resolved
@@ -1,4 +1,4 @@
-"""Public API Functions."""
+"""Array-agnostic implementations for the public API."""
Contributor

Except.... it isn't agnostic, see for example the special paths in at and nunique

Member Author

@lucascolley lucascolley Jan 15, 2025

yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up.

src/array_api_extra/_lib/_libraries.py Outdated Show resolved Hide resolved
src/array_api_extra/_lib/_libraries.py Outdated Show resolved Hide resolved
"""
is_namespace_func = getattr(_compat, f"is_{self.library_name}_namespace")
is_namespace_func = cast(Callable[[ModuleType], bool], is_namespace_func)
return is_namespace_func(xp)
Contributor

I think this whole class feels quite over-engineered.
Maybe we could have two simple functions that read from private mappings:

def is_namespace(xp: ModuleType, library: Library) -> bool: ...
def import(library: Library) -> ModuleType: ...

We should experiment with making a fake namespace for numpy_readonly too (asarray returns a read-only array, everything else is redirected to array_api_compat.numpy)
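A minimal sketch of the "simple function reading from a private mapping" idea might look like the following. The `Library` members and the `_MODULE_NAMES` table are illustrative assumptions; the code under review instead looks up `is_{library_name}_namespace` on the compat module.

```python
from enum import Enum
from types import ModuleType


class Library(Enum):  # hypothetical members, for illustration only
    NUMPY = "numpy"
    TORCH = "torch"


# Private mapping from library to the module names accepted for it (assumed).
_MODULE_NAMES = {
    Library.NUMPY: {"numpy", "array_api_compat.numpy"},
    Library.TORCH: {"torch", "array_api_compat.torch"},
}


def is_namespace(xp: ModuleType, library: Library) -> bool:
    """Return True if `xp` is a namespace belonging to `library`."""
    return xp.__name__ in _MODULE_NAMES[library]
```

Note that `import` is a reserved word in Python, so the second suggested helper would need a different name in practice, for example a hypothetical `import_library`.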

Member Author

nice suggestion with is_namespace, thanks, done. I'm not sure how to cleanly implement import in a way that reduces over-engineering. Let me know if you know how.

We should experiment with making a fake namespace for numpy_readonly too (asarray returns a read-only array, everything else is redirected to array_api_compat.numpy)

sounds cool for a follow-up

from typing import Any

# To be changed to a Protocol later (see data-apis/array-api#589)
Array = Any # type: ignore[no-any-explicit]
Device = Any # type: ignore[no-any-explicit]
Index = Any # type: ignore[no-any-explicit]

-__all__ = ["Array", "Device", "Index", "ModuleType"]
+__all__ = ["Array", "Device", "Index"]
Contributor

I think I liked this before. If anything, we should rename it to ArrayModuleType.
I would hope that eventually a new library array_api_types defines what functions exactly an array api compatible module must declare.

Member Author

yeah, this should be fixed eventually by array-api-typing. In the meantime, feel free to submit a PR changing use of ModuleType to an ArrayNamespace alias.

tests/conftest.py Show resolved Hide resolved
src/array_api_extra/_delegation.py Outdated Show resolved Hide resolved
src/array_api_extra/_delegation.py Outdated Show resolved Hide resolved
src/array_api_extra/_delegation.py Outdated Show resolved Hide resolved
@lucascolley
Member Author

at specifically shows how you can't implement this pattern with it, as the special-case code is deep inside a private method, so you'd have to write a bunch of very bespoke private functions that pinball around different modules of array-api-extra in order to make it conform to the pattern.

Yeah, I should clarify that these two patterns are orthogonal. This delegation was motivated by @ev-br's wish to use functions directly from the underlying backend whenever there is a (pretty much) one-to-one replacement. For functions like at which are very involved with backend-specific paths throughout, they can stay separate from this delegation. As mentioned above, we should split up the file structure to reflect this in a follow-up.

Member Author

@lucascolley lucascolley left a comment

thanks for the help @ev-br @crusaderky, once I've updated the scope section of the docs this is ready from my side. Follow-ups directly related to this are tracked in the PR description.

@ev-br
Member

ev-br commented Jan 15, 2025

Ship it! :-)

@lucascolley lucascolley marked this pull request as ready for review January 15, 2025 12:55
@lucascolley lucascolley merged commit 290ebb5 into data-apis:main Jan 15, 2025
10 checks passed
@lucascolley lucascolley deleted the pad-delegate branch January 15, 2025 13:06
@ev-br
Member

ev-br commented Jan 15, 2025

Great. Now it'd be helpful to make a release or a tag so that scipy can bump the submodule to a released version.

@lucascolley
Member Author

release is done!

Labels
delegation enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: add pad
3 participants