-
Notifications
You must be signed in to change notification settings - Fork 506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] OT for Gaussian distributions #428
Conversation
Co-authored-by: Alexandre Gramfort <[email protected]>
Co-authored-by: Alexandre Gramfort <[email protected]>
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## master #428 +/- ##
==========================================
+ Coverage 94.26% 94.43% +0.17%
==========================================
Files 23 24 +1
Lines 6204 6254 +50
==========================================
+ Hits 5848 5906 +58
+ Misses 356 348 -8 |
ot/da.py
Outdated
@@ -679,112 +680,12 @@ def df(G): | |||
return G, L | |||
|
|||
|
|||
@deprecated() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is enough since teh API did not change, also you did not apss the parameter to the function so it was false
@deprecated() | |
OT_mapping_linear=deprecated(empirical_bures_wasserstein_mapping) |
ot/gaussian.py
Outdated
Cs12 = nx.sqrtm(Cs) | ||
|
||
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) | ||
W = nx.norm(ms - mt) + B |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
W = nx.norm(ms - mt) + B | |
W = nx.sqrt(nx.norm(ms - mt)**2 + B) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
norm shuould be squared and the Bures wassresein is the quare ropot of the term
ot/gaussian.py
Outdated
Cs12 = nx.sqrtm(Cs) | ||
|
||
B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) | ||
W = nx.norm(mxs - mxt) + B |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same error here, you should call the function bures_wassrestein_distance
here anyways
Types of changes
This introduce Gaussian modules, a file comprising all the modules of optimal transport for Gaussian distributions.
The OT_mapping_linear moved from da.py to gaussian.py.
I added Bures Wasserstein distance to gaussian.py.
Also update examples/gromov/plot_barycenter_fgw.py to fit the new networkx API
Motivation and context / Related issue
Add new methods for Gaussian distributions.
How has this been tested (if it applies)
The test for Bures Wasserstein distance is done on two 1D Gaussian distributions centered on 0 and 10 respectively. with the same variance The Bures Wasserstein distance is tested to be close of 10.
PR checklist