
Sinkhorn l1lp transport does not return "log" details #412

Closed
sayali7 opened this issue Dec 5, 2022 · 1 comment · Fixed by #413


sayali7 commented Dec 5, 2022

Describe the bug

When I initialize ot.da.SinkhornLpl1Transport() with log=True and then call fit(), I get the error "ValueError: too many values to unpack (expected 2)".

To Reproduce

Steps to reproduce the behavior:
I ran the following two lines of code:
1. ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02, reg_cl=1e-2, log=True, verbose=True)
2. ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

Error Message:


ValueError Traceback (most recent call last)
/tmp/ipykernel_176879/2895611854.py in
1 # Sinkhorn Transport with Group lasso regularization
2 ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02,reg_cl=1e-2,log=True,verbose=True)
----> 3 ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
4 transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
5 pd.DataFrame(transp_Xs_lpl1).head()

~/.local/lib/python3.8/site-packages/ot/da.py in fit(self, Xs, ys, Xt, yt)
1748 # deal with the value of log
1749 if self.log:
-> 1750 self.coupling_, self.log_ = returned_
1751 else:
1752 self.coupling_ = returned_

ValueError: too many values to unpack (expected 2)
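The error comes from tuple unpacking: the fit() line shown in the traceback expects returned_ to be a 2-tuple (coupling, log). If the solver instead hands back only the (ns, nt) coupling array, Python iterates over its ns rows, so the assignment fails with this ValueError whenever ns > 2. A minimal NumPy reproduction of that failure mode, assuming this is what happens inside POT; buggy_solver is a hypothetical stand-in, not POT code:

```python
import numpy as np


def buggy_solver(a, b, log=False):
    """Hypothetical solver that ignores `log` and always returns only
    the (ns, nt) coupling matrix (here the independent coupling a b^T)."""
    return np.outer(a, b)


ns, nt = 5, 4
a = np.full(ns, 1.0 / ns)  # uniform source weights
b = np.full(nt, 1.0 / nt)  # uniform target weights

returned_ = buggy_solver(a, b, log=True)

message = None
try:
    # Same pattern as the fit() line in the traceback: expects (coupling, log).
    coupling_, log_ = returned_
except ValueError as err:
    # Iterating a (5, 4) array yields 5 rows, so unpacking into 2 names fails.
    message = str(err)

print(message)
```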

Code sample

ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e02, reg_cl=1e-2, log=True, verbose=True)
ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)

Expected behavior

fit() should run without error and leave both the coupling (coupling_) and the log details (log_) on the fitted object.

Environment:

  • OS (e.g. MacOS, Windows, Linux): Ubuntu 20
  • Python version: 3.8
  • How was POT installed (source, pip, conda): pip

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-5.4.0-132-generic-x86_64-with-glibc2.29
Python 3.8.10 (default, Jun 22 2022, 20:18:18)
[GCC 9.4.0]
NumPy 1.21.4
SciPy 1.8.0
POT 0.8.1.0


rflamary (Collaborator) commented Dec 6, 2022

Good catch!

I fixed it in PR #413, which should be merged soon. In the meantime, you can use the bug_log_lpl1 branch.
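A minimal sketch of the return contract that the fit() branch relies on, assuming (as the traceback suggests) that the lpl1 solver returned only the coupling matrix even when log=True. The names sinkhorn_sketch and TransportSketch are hypothetical, and the solver is plain entropic Sinkhorn without the lpl1 group-lasso term, used only as a stand-in. This is not POT's implementation or the code in PR #413:

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, num_iter=200, log=False):
    """Plain entropic Sinkhorn WITHOUT the lpl1 group-lasso term.

    Hypothetical stand-in used only to show a consistent return
    contract: returns G, or (G, log_dict) when log=True.
    """
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter):
        u = a / (K @ v)    # match row marginals
        v = b / (K.T @ u)  # match column marginals
    G = u[:, None] * K * v[None, :]
    if log:
        return G, {"u": u, "v": v, "n_iter": num_iter}
    return G


class TransportSketch:
    """Hypothetical estimator mirroring the fit() branch in the traceback."""

    def __init__(self, reg=1.0, log=False):
        self.reg = reg
        self.log = log

    def fit(self, Xs, Xt):
        a = np.full(len(Xs), 1.0 / len(Xs))
        b = np.full(len(Xt), 1.0 / len(Xt))
        # Squared Euclidean cost, scaled to [0, 1] for numerical stability.
        M = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)
        M = M / M.max()
        returned_ = sinkhorn_sketch(a, b, M, self.reg, log=self.log)
        # Unpacking is safe only because the solver honours `log`.
        if self.log:
            self.coupling_, self.log_ = returned_
        else:
            self.coupling_ = returned_
        return self


rng = np.random.default_rng(0)
Xs = rng.normal(size=(6, 2))
Xt = rng.normal(size=(5, 2))
est = TransportSketch(reg=1.0, log=True).fit(Xs, Xt)
print(est.coupling_.shape, sorted(est.log_))  # (6, 5) ['n_iter', 'u', 'v']
```

The key design point is that the solver's return shape depends on the same log flag the estimator branches on, so the two can never disagree.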
