Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Refactor the function utils.cost_normalization to work with multiple backends #472

Merged
merged 16 commits into from
May 10, 2023

Conversation

framunoz
Copy link
Contributor

@framunoz framunoz commented May 3, 2023

Types of changes

  • The utils.cost_normalization function now uses backends in a general way.
  • The backends now implements the 'median' function.

Motivation and context / Related issue

The motivation for making this change is that currently, when using a backend other than Numpy (e.g. PyTorch), the program crashes when trying to use the utils.cost_normalization function. This function is also essential in the ot.da.SinkhornTransport class, for example.

This PR fix the issue #465

How has this been tested (if it applies)

I modified the existing test_cost_normalization test to use the different types of backends, and then passed them back to Numpy to keep the tests that existed before.

In addition, I have added new cases to the backend tests, to include the new median function.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary rflamary changed the title Refactor the function utils.cost_normalization to work with multiple backends [FIX] Refactor the function utils.cost_normalization to work with multiple backends May 4, 2023
@rflamary
Copy link
Collaborator

rflamary commented May 5, 2023

Hello this is a pain for tensorflow, You shouldl the numpy median and then convert back to tensorflow afterward in the backend. You loose differentiability but median is non differentiable anyways and it is not necessary for DA methods and it will work. maybe just do a warning with wrarning.warning that median in computed using numpy and the array is detached in the tf backdn function.

@framunoz
Copy link
Contributor Author

framunoz commented May 6, 2023

I see that the error raised in the test was due to an incompatibility of the version with Pytorch. Maybe with a similar strategy used in Tensorflow will work?

EDIT: I see that the interpolation parameter was added in the version 1.11.0 in Pytorch, so, the method evaluates the version first, and decides whether to use the torch.quantile function or transform it to Numpy, showing a warning in this case.

@rflamary rflamary merged commit 8cc8dd2 into PythonOT:master May 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants