
Commit

Applied PR #49 for issue #57.
jeanfeydy committed Jun 18, 2022
1 parent 54f0eb5 commit 69a3130
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions geomloss/ot/abstract_solvers/sinkhorn_ot.py
@@ -230,6 +230,7 @@ def sinkhorn_loop(
# "Super-efficiency of automatic differentiation for
# functions defined as a minimum", Ablin, Peyré, Moreau (2020)
# https://arxiv.org/pdf/2002.03722.pdf.
prev_autograd = torch.is_grad_enabled()
torch.autograd.set_grad_enabled(False)

# Line 1 (in Algorithm 3.6 from Jean Feydy's PhD thesis) ---------------------------
@@ -325,7 +326,7 @@ def sinkhorn_loop(
             C_xx_fine, C_yy_fine = C_xxs[k + 1], C_yys[k + 1]
 
             last_extrapolation = False  # No need to re-extrapolate after the loop
-            torch.autograd.set_grad_enabled(True)
+            torch.autograd.set_grad_enabled(prev_autograd)
 
         else:  # It's worth investing some time on kernel truncation...
             # The lines below implement the Kernel truncation trick,
@@ -411,7 +412,7 @@ def sinkhorn_loop(
     # As detailed above (around "torch.autograd.set_grad_enabled(False)"),
     # this allows us to retrieve correct expressions for the gradient
     # without having to backprop through the whole Sinkhorn loop.
-    torch.autograd.set_grad_enabled(True)
+    torch.autograd.set_grad_enabled(prev_autograd)
 
     if last_extrapolation:
         # The cross-updates should be done in parallel!
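For context, the change replaces a hard-coded torch.autograd.set_grad_enabled(True) with a restore of the autograd state that was saved just before the solver disables gradients, so that calling the solver inside a torch.no_grad() block no longer re-enables autograd for the caller. The sketch below illustrates the save-and-restore pattern in isolation; solver_step is a hypothetical stand-in, not the actual geomloss function.

import torch

def solver_step(x):
    # Remember the caller's autograd state before disabling it, instead of
    # assuming it was True.
    prev_autograd = torch.is_grad_enabled()
    torch.autograd.set_grad_enabled(False)

    # ... run the iterations without building a computational graph ...
    y = x * 2  # placeholder for the actual no-grad updates

    # Restore whatever state the caller had: True inside a normal training
    # loop, False inside a torch.no_grad() block.
    torch.autograd.set_grad_enabled(prev_autograd)
    return y

# With the old code (set_grad_enabled(True)), this call would silently
# re-enable autograd for the rest of the no_grad block; with the fix, the
# caller's state is preserved.
with torch.no_grad():
    out = solver_step(torch.ones(3))
    assert not torch.is_grad_enabled()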
