Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DA] Sinkhorn L1L2 transport to work on JAX #587

Merged
merged 18 commits into from
Dec 22, 2023

Conversation

kachayev
Copy link
Collaborator

@kachayev kachayev commented Dec 3, 2023

Types of changes

All operations required for SinkhornL1l2Transport to work on JAX are properly vectorized, including those implemented in the BaseTransport. In short

  • per-labels for-loops are vectorized using mask tensors, implementation is moved to a labels_to_masks helper with corresponding tests
  • a new backend method nan_to_num
  • JAX backend was removed from the exclusion list in BaseEstimator
  • a few enhancements for label normalization and related operations (including avoid unnecessary computations when normalizing labels)

Motivation and context / Related issue

The next step towards making domain adaptation methods to work on JAX backend, continues the work started with #507.

How has this been tested (if it applies)

  • test_sinkhorn_l1l2_transport_class test doesn't skip JAX backend
  • the test also updated to check semi-supervised use case
  • additional test cases for label_normalization and labels_to_masks helpers

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Additional Context

Working on the implementation I spotted the following (potential issue). It seems that the tests for semi-supervised DA, in fact, do not cover semi-supervised use case. They test the different between unsupervised (no labels for target) and supervised (target labels are available). For the test_sinkhorn_l1l2_transport_class specifically I did update the implementation to use partially masked labels for targets (see otda_semi). Does it covers the expected functionality correctly?

Copy link

codecov bot commented Dec 3, 2023

Codecov Report

Merging #587 (7774b16) into master (acd84ed) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #587   +/-   ##
=======================================
  Coverage   96.74%   96.75%           
=======================================
  Files          77       77           
  Lines       15911    15939   +28     
=======================================
+ Hits        15393    15421   +28     
  Misses        518      518           

@rflamary rflamary merged commit 9ddb690 into PythonOT:master Dec 22, 2023
15 checks passed
@kachayev kachayev deleted the da-on-jax branch December 22, 2023 12:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants