-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: pad
: add delegation
#72
Conversation
f90cb78
to
d17fd2f
Compare
58ed3dc
to
38690bb
Compare
@ev-br this is what I'm thinking. Of course, for only 1 function here this overcomplicates things, but I think we will see the benefits of cleanly separating which files are using the standard API and which are using library-specific functions, if/when more delegation is added. And it leaves the door open for more "smart" delegation if that would become appropriate. I'd like to leave this as draft until gh-53 is merged, as that PR adds a way to test with existing array libraries (and this one would cause merge conflicts anyway if we did that here first). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the main points are:
- you effectively whitelist the kernels as opposed to a catch-all delegating to
xp.pad
. An unknown backend will get a 'generic' implementation instead of being asked for apad
namespaced function. Makes sense. - the approach is way more light-touch than in scipy.{signal,ndimage,special}: in there you need either a delegation dict or a whole module. Here it's simple enough to keep the if-elif chain inside an affected function. Would be straightforward to refactor if a need arises.
A really minor point is the _delegators.py
name: in scipy this only contains the _signature
functions, and the kernels are elsewhere. Here this module contains the public symbol(s) including the if-elif delegation chains. This is fine though, and will also be simple to refactor if the need arises (that's a big fat if here).
My other inline comments are all optional and minor.
Concluding, this refactor LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TO-DO here:
- adjust scope section of the docs
- rename
_delegators.py
to_delegation.py
- add
pad
to the api reference - figure out where the function docstrings should live
- test the delegation
follow-up:
- add a section on adding delegation to the contributor docs
- open a tracker for adding delegation to functions
Tried using this on scipy/scipy#22226 and apparently there's an issue with the 'generic' implementation: downstream usage involves
|
@ev-br could you take another look at the delegation layer here now that the |
With this patch, $ git diff
diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py
index 8ecff02..b8b6423 100644
--- a/src/array_api_extra/_delegation.py
+++ b/src/array_api_extra/_delegation.py
@@ -117,7 +117,7 @@ def pad(
pad_width = xp.asarray(pad_width)
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
- return xp.nn.functional.pad(x, (pad_width,), value=constant_values)
+ return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values)
if _delegate(xp, NUMPY, JAX, CUPY):
return xp.pad(x, pad_width, mode, constant_values=constant_values)
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index bada663..2a5a122 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -1,8 +1,9 @@
import contextlib
import warnings
-# data-apis/array-api-strict#6
-import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
+import torch as xp
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
@@ -20,6 +21,7 @@ from array_api_extra import (
from array_api_extra._lib._utils._typing import Array
+
class TestAtLeastND:
def test_0D(self):
x = xp.asarray(1)
@@ -415,6 +417,10 @@ class TestPad:
assert pad(a, 2).device == device
def test_xp(self):
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
def test_tuple_width(self):
@@ -425,7 +431,7 @@ class TestPad:
padded = pad(a, (1, 2))
assert padded.shape == (6, 7)
- with pytest.raises(ValueError, match="expect a 2-tuple"):
+ with pytest.raises((ValueError, RuntimeError)): #, match="expect a 2-tuple"): NB: exc types and messages differ!
pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType]
def test_list_of_tuples_width(self):
I see two failures:
The second one is interesting: since the torch code path is called with an explicit |
If |
Looks like delegation is not working at all: To test, apply the following diff to scipy/scipy#22122 diff --git a/scipy/signal/_signaltools.py b/scipy/signal/_signaltools.py
index c208278744..eb39956417 100644
--- a/scipy/signal/_signaltools.py
+++ b/scipy/signal/_signaltools.py
@@ -871,6 +871,13 @@ def _split(x, indices_or_sections, axis, xp):
def xp_pad(x, pad_width, mode='constant', *, xp, **kwargs):
+
+ breakpoint()
+
+ import array_api_extra as xpx
+ return xpx.pad(x, pad_width, mode=mode, xp=xp, **kwargs)
+
+ then pip editable install
|
I can't reproduce immediately:
|
the implementation of return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
...
def _compat_module_name():
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers') My inkling is that this fails for you because you're mixing a namespace from |
To me it looks like a bad form to depend on how a package is installed. |
Sounds good if you can find a way to stop depending on that, but I don't think it needs to block this. |
Well, I certainly don't want to block this. If you think it's fine and your testing of scipy shows that it works for the vendored copy, great, let's roll with it and fix issues when found. It'll be quite a bit easier when it's released, so there's no need for hacks to even test this. |
Yep, makes sense. We should still run some torch/jax tests over this at least in CI here before releasing. |
@ev-br any chance you could take a look at the CI failures? |
@crusaderky does |
AFAICS, all CI failures come from testing for an error
and simply signal the fact that the So a fix is simply to relax the regex or drop it altogether. You're testing for an error and an error you get. End of story, move on, I'd say. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sorry to be blunt, but I have to give a solid -1 to this.
The whole PR feels very over-engineered. I think that the current approach of at
and nunique
, which is to have a special-case code path for selected backends, is a lot more readable and maintainable than this, without having any drawbacks short of ideological purity.
at
specifically shows how you can't implement this pattern with it, as the special-case code is deep inside a private method, so you'd have to write a bunch of very bespoke private functions that pinball around different modules of array-api-extra
in order to make it conform to the pattern.
@@ -1,4 +1,4 @@ | |||
"""Public API Functions.""" | |||
"""Array-agnostic implementations for the public API.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Except.... it isn't agnostic, see for example the special paths in at
and nunique
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up.
""" | ||
is_namespace_func = getattr(_compat, f"is_{self.library_name}_namespace") | ||
is_namespace_func = cast(Callable[[ModuleType], bool], is_namespace_func) | ||
return is_namespace_func(xp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this whole class feels quite over-engineered.
Maybe we could have two simple functions that read from private mappings:
def is_namespace(xp: ModuleType, library: Library) -> bool: ...
def import(library: Library) -> ModuleType: ...
We should experiment with making a fake namespace for numpy_readonly too (asarray returns a read-only array, everything else is redirected to array_api_compat.numpy)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice suggestion with is_namespace
, thanks, done. I'm not sure how to cleanly implement import
in a way that reduces over-engineering. Let me know if you know how.
We should experiment with making a fake namespace for numpy_readonly too (asarray returns a read-only array, everything else is redirected to array_api_compat.numpy)
sounds cool for a follow-up
from typing import Any | ||
|
||
# To be changed to a Protocol later (see data-apis/array-api#589) | ||
Array = Any # type: ignore[no-any-explicit] | ||
Device = Any # type: ignore[no-any-explicit] | ||
Index = Any # type: ignore[no-any-explicit] | ||
|
||
__all__ = ["Array", "Device", "Index", "ModuleType"] | ||
__all__ = ["Array", "Device", "Index"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I liked this before. If anything, we should rename it to ArrayModuleType
.
I would hope that eventually a new library array_api_types
defines what functions exactly an array api compatible module must declare.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, this should be fixed eventually by array-api-typing. In the meantime, feel free to submit a PR changing use of ModuleType
to an ArrayNamespace
alias.
Yeah, I should clarify that these two patterns are orthogonal. This delegation was motivated by @ev-br's wish to use functions directly from the underlying backend whenever there is a (pretty much) one-to-one replacement. For functions like |
Co-authored-by: Guido Imperiale <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the help @ev-br @crusaderky, once I've updated the scope section of the docs this is ready from my side. Follow-ups directly related to this are tracked in the PR description.
Ship it! :-) |
Great. Now it'd be helpful to make a release or a tag so that scipy can bump the submodule to a released version. |
release is done! |
closes gh-69
TO-DO here:
_delegators.py
to_delegation.py
pad
to the api referencefollow-up:
_funcs.py