-
Notifications
You must be signed in to change notification settings - Fork 506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FIX] Refactor the function utils.cost_normalization
to work with multiple backends
#472
Conversation
…ts this method in the Numpy, Pytorch, Jax and Cupy backends
utils.cost_normalization
to work with multiple backendsutils.cost_normalization
to work with multiple backends
Hello this is a pain for tensorflow, You shouldl the numpy median and then convert back to tensorflow afterward in the backend. You loose differentiability but median is non differentiable anyways and it is not necessary for DA methods and it will work. maybe just do a warning with wrarning.warning that median in computed using numpy and the array is detached in the tf backdn function. |
I see that the error raised in the test was due to an incompatibility of the version with Pytorch. Maybe with a similar strategy used in Tensorflow will work? EDIT: I see that the |
…ackend change using numpy
…her to use torch.quantile or numpy
� Conflicts: � RELEASES.md
Types of changes
utils.cost_normalization
function now uses backends in a general way.Motivation and context / Related issue
The motivation for making this change is that currently, when using a backend other than Numpy (e.g. PyTorch), the program crashes when trying to use the
utils.cost_normalization
function. This function is also essential in theot.da.SinkhornTransport
class, for example.This PR fix the issue #465
How has this been tested (if it applies)
I modified the existing
test_cost_normalization
test to use the different types of backends, and then passed them back to Numpy to keep the tests that existed before.In addition, I have added new cases to the backend tests, to include the new
median
function.PR checklist