An Error (RuntimeWarning: invalid value encountered in log) in ot.da.SinkhornL1lL2 after some iterations #311

Closed
EnayatAria opened this issue Nov 16, 2021 · 12 comments · Fixed by #312 or #313

Comments

@EnayatAria

EnayatAria commented Nov 16, 2021

Using ot.da.SinkhornL1l2Transport for a domain adaptation problem, I faced an error as follows:

Datasets used:

Xs.txt
Xt.txt
ys.txt

**To Reproduce**

If you download the input files to C:\, the code is:

import numpy as np
import ot

Xs = np.loadtxt("C:/Xs.txt").reshape(604, 5)
Xt = np.loadtxt("C:/Xt.txt").reshape(601, 5)
ys = np.loadtxt("C:/ys.txt")

ot_base = ot.da.SinkhornL1l2Transport(reg_e=10000, reg_cl=100, max_iter=100, verbose=True)
ot_base.fit(Xs=Xs, ys=ys, Xt=Xt)

Result and Error

 It.  |Loss        |Relative loss|Absolute loss
------------------------------------------------
0|5.193677e+06|0.000000e+00|0.000000e+00
1|3.150847e+05|1.548343e+01|4.878593e+06
2|2.668420e+05|1.807914e-01|4.824274e+04
3|2.663638e+05|1.795333e-03|4.782117e+02
4|2.663590e+05|1.786689e-05|4.759007e+00
5|2.663590e+05|1.312580e-07|3.496174e-02
6|2.663588e+05|7.339658e-07|1.954982e-01
7|2.663106e+05|1.808094e-04|4.815146e+01

C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\optim.py:357: RuntimeWarning: invalid value encountered in log
  return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2021.1.2\plugins\python-ce\helpers\pydev\pydevd.py", line 1483, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2021.1.2\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "C:/Users/enayat.aria/PycharmProjects/pythonProject/OT_for_DA.py", line 258, in <module>
    ot_base.fit(Xs=Xs, ys=ys, Xt=Xt)
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\da.py", line 1950, in fit
    returned_ = sinkhorn_l1l2_gl(
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\da.py", line 239, in sinkhorn_l1l2_gl
    return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
  File "C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\optim.py", line 388, in gcg
    G = G + alpha * deltaG
TypeError: unsupported operand type(s) for *: 'NoneType' and 'float'

Checking the parameters, I found that the G matrix computed in optim.py contains negative values at the last iteration, as a result of the last update.

Please let me know how to solve the problem, or whether I should provide more information.

Best,

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows 10
  • Python version: 3.9
  • How was POT installed (source, pip, conda): pip
@hichamjanati
Contributor

@EnayatAria Can you elaborate? Your 'To Reproduce' section is missing.

@EnayatAria EnayatAria changed the title from "Convergence problem in ot.da.SinkhornL1L2" to "An Error (RuntimeWarning: invalid value encountered in log) in ot.da.SinkhornL1lL2 after some iterations" on Nov 16, 2021
@EnayatAria
Author

EnayatAria commented Nov 16, 2021 via email

@hichamjanati
Contributor

Can you provide a small code snippet to reproduce the error so I can debug it?

@EnayatAria
Author

The code is a simple call to the function:

ot_base = ot.da.SinkhornL1l2Transport(reg_e=10000, reg_cl=100, max_iter=100, verbose=True)
ot_base.fit(Xs=Xs, ys=ys, Xt=Xt)

The error does not come from the code but from the input values, since the same code works with another dataset. With this dataset, however, it raises the error after a couple of iterations.

@EnayatAria
Author

Datasets used:

Xs.txt
Xt.txt
ys.txt

Xs, ys, and Xt are attached. Their shapes are Xs: ndarray (604, 5); Xt: ndarray (601, 5); ys: ndarray (604,).

Given these dimensions, the files should be reshaped after loading, for example:
original_Xs = np.loadtxt("Xs.txt").reshape(604, 5)

@EnayatAria
Author

I have updated the first comment with the code and the datasets needed to reproduce the error.

@hichamjanati
Contributor

These numerical errors are probably caused by the large values in your data. Try normalizing the data first, or normalize the cost with the norm argument:

ot_base = ot.da.SinkhornL1l2Transport(reg_e=1, reg_cl=100, max_iter=100, verbose=True, norm="median")

Hope this fixes your problem.
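
For readers following along, here is a minimal sketch of what rescaling a cost matrix by its median amounts to, assuming that norm="median" divides the pairwise cost matrix by its median entry. It is written in plain Python to stay dependency-free. The helper names `squared_euclidean_cost` and `normalize_cost_by_median` and the toy points are hypothetical illustrations, not POT functions or the attached data.

```python
from statistics import median


def squared_euclidean_cost(Xs, Xt):
    """Pairwise squared Euclidean distances between two lists of points."""
    return [[sum((a - b) ** 2 for a, b in zip(xs, xt)) for xt in Xt] for xs in Xs]


def normalize_cost_by_median(M):
    """Divide every entry of a cost matrix (list of rows) by its median entry."""
    med = median(c for row in M for c in row)
    if med == 0:
        raise ValueError("median cost is zero; cannot normalize")
    return [[c / med for c in row] for row in M]


# Toy points with large coordinates (invented for illustration).
Xs = [[0.0, 0.0], [1000.0, 0.0]]
Xt = [[0.0, 500.0], [2000.0, 0.0], [1000.0, 1000.0]]

M = squared_euclidean_cost(Xs, Xt)   # entries on the order of 1e6
M_norm = normalize_cost_by_median(M)  # entries on the order of 1
```

After rescaling, the costs are of order one, so a regularization value such as reg_e=1 operates on a comparable scale instead of competing with entries in the millions.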

@rflamary
Collaborator

There should never be negative values in G during the optim iterations. This might come from the solver; we should look into it.
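
The loss line quoted in the warning evaluates G * log(G), and the logarithm is undefined for negative entries, which is consistent with the "invalid value encountered in log" message. Below is a minimal sketch of an entropy term that makes this failure explicit instead of silently producing NaN. The function name `entropy_term` is hypothetical, and the convention 0 * log(0) = 0 is an assumption of this sketch, not necessarily what POT does.

```python
import math


def entropy_term(G):
    """Sum of g * log(g) over a plan given as a list of rows.

    Zero entries contribute 0 (the usual 0 * log 0 = 0 convention);
    negative entries raise, because log is undefined there.
    """
    total = 0.0
    for row in G:
        for g in row:
            if g < 0:
                raise ValueError(f"negative plan entry {g!r}: log is undefined")
            if g > 0:
                total += g * math.log(g)
    return total


# A valid plan with zeros evaluates without trouble.
G_ok = [[0.5, 0.0], [0.0, 0.5]]
value = entropy_term(G_ok)
```

A check like this can help locate the iteration at which the plan first leaves the nonnegative orthant.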

@EnayatAria
Author

EnayatAria commented Nov 17, 2021

Thank you for the solution, but it only stops the iterations early and returns the final G from optim.py, which still contains negative values.

I am wondering whether the result would be the optimal solution of the general regularized OT problem.

'alpha' is 'None' when none of the conditions in the scalar_search_armijo function in linesearch.py is met. There is a comment above the last 'return' that says: # Failed to find a suitable step length.

What does that mean? Does it mean that the solver cannot converge?

The alpha values returned by line_search_armijo in optim.py for the provided datasets are:

iter alpha
0 0.99
1 0.99
2 0.99
3 0.99
4 0.99
5 -340.96758804319694
6 -271.4559025260829
7 None

The code again raised the same warning as before:

....\site-packages\ot\optim.py:357: RuntimeWarning: invalid value encountered in log
return np.sum(M * G) + reg1 * np.sum(G * np.log(G)) + reg2 * f(G)

but this time it does not stop.
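
The alpha values listed above suggest one way negative entries can arise. Assuming the update has the form G + alpha * (Gc - G), with G and Gc both nonnegative (as the traceback line `G = G + alpha * deltaG` hints), any alpha in [0, 1] yields a convex combination that stays nonnegative, whereas a negative alpha can push entries below zero. The sketch below shows one possible guard: treat a failed search (None) as a zero step and clamp alpha to [0, 1]. The names `safe_step` and `cg_update` are hypothetical, and this is not necessarily what the linked fixes do.

```python
def safe_step(alpha, lo=0.0, hi=1.0):
    """Return a usable step size: None (failed line search) becomes lo,
    anything else is clamped to the interval [lo, hi]."""
    if alpha is None:
        return lo
    return min(hi, max(lo, alpha))


def cg_update(G, Gc, alpha):
    """Conditional-gradient style update G + a * (Gc - G) with a guarded step."""
    a = safe_step(alpha)
    return [[g + a * (gc - g) for g, gc in zip(row_g, row_gc)]
            for row_g, row_gc in zip(G, Gc)]


# Two nonnegative toy plans (invented for illustration).
G = [[0.5, 0.0], [0.0, 0.5]]
Gc = [[0.0, 0.5], [0.5, 0.0]]

# Unguarded update with the negative step reported at iteration 5.
bad_alpha = -340.96758804319694
raw = [[g + bad_alpha * (gc - g) for g, gc in zip(rg, rgc)]
       for rg, rgc in zip(G, Gc)]

# Guarded updates: the plan stays nonnegative in every case.
guarded_negative = cg_update(G, Gc, bad_alpha)
guarded_none = cg_update(G, Gc, None)
guarded_normal = cg_update(G, Gc, 0.99)
```

With the guard, a failed or negative step leaves G unchanged, so the solver can stop cleanly rather than evaluating the loss on an invalid plan.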

@EnayatAria
Author

EnayatAria commented Nov 18, 2021

Thanks for the modifications. It now works up to iteration 13, but then another error appears in backend.py and the iterations stop. The result and the error for the given dataset are as follows:

 It.  |Loss        |Relative loss|Absolute loss
------------------------------------------------
0|5.065656e+06|0.000000e+00|0.000000e+00
1|3.663805e+05|1.282621e+01|4.699275e+06
2|3.204103e+05|1.434732e-01|4.597028e+04
3|3.199586e+05|1.411632e-03|4.516638e+02
4|3.199541e+05|1.408355e-05|4.506090e+00
5|3.199540e+05|2.988819e-07|9.562846e-02
6|3.199540e+05|3.396901e-08|1.086852e-02
7|3.199539e+05|3.162696e-07|1.011917e-01
8|3.199539e+05|4.462186e-08|1.427694e-02
9|3.199539e+05|1.630207e-08|5.215909e-03
10|3.199539e+05|1.724603e-08|5.517934e-03
11|3.199538e+05|3.873306e-08|1.239279e-02
12|3.199538e+05|4.256753e-08|1.361964e-02
C:\Users\enayat.aria\PycharmProjects\pythonProject\venv\lib\site-packages\ot\backend.py:754: RuntimeWarning: invalid value encountered in log
  return np.log(a)
13|3.199538e+05|0.000000e+00|0.000000e+00

Working from the source code, I only updated utils.py and added backend.py to the ot folder (backend.py was not in my ot folder before). Is this a bug, or should I update other files as well?

@rflamary
Collaborator

When testing, please use an environment and install directly from the master branch of the repository.

If you don't have backend.py, it means that you are on an old version of POT.

@EnayatAria
Author

Thank you for your effort and responsibility. It is working now!
