
Add a centered variance option to the ClippedAdam optimizer #3415

Open · wants to merge 5 commits into dev

Conversation

@BenZickel (Contributor) commented Jan 21, 2025

Problem

When using the ClippedAdam optimizer on a model where gradient stability is highly imbalanced across parameters, the convergence rate of the parameters with stable gradients is slower than it could be.

Solution

Add an option to use the centered variance in the denominator of the step size calculation. Parameters with stable gradients will have a lower centered variance than the current uncentered variance, and will therefore get a larger step size and a higher convergence rate.
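
For illustration, here is a minimal sketch of the idea as an Adam-style update (this is not the PR's actual implementation; buffer names follow torch.optim.Adam conventions, and the centered_variance flag is an assumed name):

```python
import torch


def adam_like_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                   betas=(0.9, 0.999), eps=1e-8, centered_variance=False):
    """Sketch of one Adam-style update with an optional centered variance."""
    beta1, beta2 = betas
    # Exponential moving averages of the gradient and the squared gradient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Standard Adam bias corrections (unchanged in the centered case).
    m_hat = exp_avg / (1 - beta1**step)
    v_hat = exp_avg_sq / (1 - beta2**step)
    if centered_variance:
        # Centered variance: subtract the squared bias-corrected mean.
        # For a parameter with stable gradients E[g]^2 is close to E[g^2],
        # so the denominator shrinks and the effective step size grows.
        v_hat = (v_hat - m_hat**2).clamp(min=0)  # clamping is just a guard in this sketch
    param.add_(-lr * m_hat / (v_hat.sqrt() + eps))
```

The only change relative to the uncentered update is the denominator; the numerator and the bias corrections stay the same.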

Testing

The improvement in convergence rate is shown below (taken from the test function run with plotting enabled):

  • The first plot shows the number of iterations needed to reach convergence, where convergence is defined as reaching the ultimate loss plus a small threshold.
  • The second plot shows the convergence rate, i.e. the mean per-iteration improvement of the gap between the loss and the ultimate loss, which is roughly proportional to the inverse of the number of iterations needed to reach convergence (a rough sketch of this computation follows the figure). Note also that the convergence rate is less sensitive to changes in the learning rate with the centered variance than with the uncentered variance.
  • The third plot shows the ultimate loss reached; for regular Adam with uncentered variance, the best possible loss (zero in this case) is not attained for small learning rates within the allotted number of iterations.
    [Figure: iterations to convergence, convergence rate, and ultimate loss as a function of learning rate, for centered vs. uncentered variance]
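
As a rough sketch of how these metrics can be computed from a recorded loss trace (illustrative only; the actual test code in this PR may differ, and loss_vec, tail_len, and threshold are assumed names):

```python
import torch


def convergence_metrics(loss_vec, tail_len=100, threshold=0.01):
    """Illustrative convergence metrics from a per-iteration loss trace."""
    # Ultimate loss: average loss over the tail of the trace.
    ultimate_loss = loss_vec[-tail_len:].mean()
    # Gap between the loss and the ultimate loss at each iteration.
    gap = (loss_vec - ultimate_loss).clamp(min=1e-12)
    # Iterations needed to get within `threshold` of the ultimate loss.
    converged = gap <= threshold
    convergence_iter = int(converged.nonzero()[0]) if converged.any() else len(loss_vec)
    # Convergence rate: mean per-iteration log improvement of the gap,
    # roughly proportional to the inverse of the iterations to convergence.
    convergence_rate = (gap[:-1] / gap[1:]).log().mean()
    return ultimate_loss, convergence_rate, convergence_iter
```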

@martinjankowiak (Collaborator) commented

@BenZickel can you please explain your figure? i don't know how a convergence rate is computed, and i can't tell if the differences in the second plot are significant given the scale

@martinjankowiak (Collaborator) commented

a bit of googling led me here. the same algo in essence? a 2 second scan suggests they do bias correction

https://edoc.hu-berlin.de/server/api/core/bitstreams/14960a8d-4c35-4d08-86d7-1e130ecd42c8/content

@BenZickel (Contributor, Author) commented

Thanks for the review @martinjankowiak.

  • I've added an additional figure and explanations about the plots. I think the first plot is the most important, as it shows the number of iterations needed in order to reach convergence.
  • Regarding the reference you provided, the algorithm described is indeed the same, so I've added it to the references section. They describe the algorithm as using the true variance instead of the second moment of the gradient ("One has to note that Adam does not actually calculate the variance but the second moment instead. If E[gt] = 0, both definitions are identical, and fewer operations are required to calculate the second moment."). As for bias correction, they conclude that it does not need to change ("Finally, from the bias-corrected expression for vt, we conclude that the bias correction for cAdam is the same as that of Adam.").
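
In symbols (a sketch of the relation being discussed, written in standard Adam notation rather than the exact expressions of the paper or the PR):

$$
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad
\hat{c}_t = \hat{v}_t - \hat{m}_t^2,
$$

where the centered option uses $\sqrt{\hat{c}_t}$ instead of $\sqrt{\hat{v}_t}$ in the denominator of the update, the bias-correction factors are the same as in standard Adam, and the two denominators coincide when $\mathbb{E}[g_t] = 0$.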

@BenZickel (Contributor, Author) commented

I've added the centered variance option to the Latent Dirichlet Allocation (LDA) example. When running some tests I've noticed that the centered variance option improves both the convergence rate and the ultimate loss over a wide range of learning rates. Additionally, the same phenomenon seen in the example above, of reduced sensitivity of the convergence rate to changes in the learning rate, can be observed when using the centered variance option.

The centered variance option in the LDA example can be used by running

python pyro-ppl/examples/lda.py -cv True
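
For reference, using the option directly in an SVI setup would look roughly like the sketch below (the centered_variance keyword name is an assumption and should be checked against the PR's ClippedAdam signature; the toy model is only for illustration):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam


def model(data):
    # Toy model: a single latent location with Gaussian observations.
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


def guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(loc_q, 0.1))


data = torch.randn(100) + 3.0

# The centered_variance flag is the new option added by this PR (name assumed).
optimizer = ClippedAdam({"lr": 0.01, "clip_norm": 10.0, "centered_variance": True})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(1000):
    svi.step(data)
```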


 Small modification to the Adam algorithm implemented in torch.optim.Adam
-to include gradient clipping and learning rate decay.
+to include gradient clipping and learning rate decay and an option to use
+the centered variance.
Review comment (Collaborator):

can you point to the ref here?

@@ -435,3 +435,105 @@ def step(svi, optimizer):
actual.append(step(svi, optimizer))

assert_equal(actual, expected)


def test_centered_clipped_adam(plot_results=False):
Review comment (Collaborator):

how long does this test take?

loss_vec.append(loss)
return torch.Tensor(loss_vec)

def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
Review comment (Collaborator):

comment what is being computed?

convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
return ultimate_loss, convergence_rate, convergence_iter

def get_convergence_vec(lr_vec, centered_variance):
Review comment (Collaborator):

comment what is being computed?

@martinjankowiak (Collaborator) commented

thanks @BenZickel ! the motivation makes sense, and i can imagine how this might help. i'm perhaps somewhat surprised by the size of the effect, though i guess your w has quite a range in magnitude, perhaps a larger range than we might expect to see in most scenarios
