
Implementing linear conjugate gradients #62

Merged 12 commits into pymc-devs:main on Oct 1, 2022
Conversation

@kunalghosh (Contributor) commented Aug 2, 2022

This is a draft implementation of the linear conjugate gradients algorithm.
It is designed to produce the same output as gpytorch.utils.linear_cg, and the code snippet below checks that the two give the same results.

Let's first generate the data:

import torch 
import gpytorch
import numpy as np

# Test PSD Matrix

N = 10
rank = 5
np.random.seed(1234) # nans with seed 1234
K = np.random.randn(N, N)
K = K @ K.T + N * np.eye(N)
K_torch = torch.from_numpy(K)
print(np.diag(K))

y = np.random.randn(N, 1)
y_torch = torch.from_numpy(y)

Now we get the output from GPyTorch:

result, tridiag = gpytorch.utils.linear_cg(K_torch, y_torch, n_tridiag=1)

And also from our implementation:

import pymc_experimental as pymx
result_pmx, tridiag_pmx = pymx.utils.linear_cg(K, y, n_tridiag=1)

Check the outputs:

assert np.allclose(result, result_pmx), "BUG: result doesn't match gpytorch values"
assert np.allclose(tridiag, tridiag_pmx), "BUG: tridiagonal values don't match gpytorch values"
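
For reference, here is a minimal plain-NumPy sketch of the textbook linear conjugate gradients iteration that these routines implement: it solves K x = y for a symmetric positive-definite K. This is only an illustration of the algorithm, not the PR's actual implementation; it omits the preconditioning and the tridiagonal (Lanczos) bookkeeping that the n_tridiag argument adds.

import numpy as np

def linear_cg_reference(A, b, tol=1e-10, max_iter=1000):
    # Textbook CG for A x = b with A symmetric positive-definite.
    x = np.zeros_like(b)
    r = b - A @ x              # residual
    p = r.copy()               # search direction
    rs_old = float(r.T @ r)
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rs_old / float(p.T @ Ap)  # step size along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = float(r.T @ r)
        if np.sqrt(rs_new) < tol:         # converged
            break
        p = r + (rs_new / rs_old) * p     # next conjugate direction
        rs_old = rs_new
    return x

With the K and y generated above, np.allclose(K @ linear_cg_reference(K, y), y) should hold.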

@kunalghosh marked this pull request as draft on August 2, 2022, 15:07
@kunalghosh (Contributor, Author)

@bwengals Ping, can you take a look at the PR please? :)

@bwengals (Contributor) commented Aug 3, 2022

Will do. It might take me a couple of days since it's a lot to go through, but it looks like you got it working. Congrats!


result = np.copy(initial_guess)

if np.allclose(residual, residual):
@kunalghosh (Contributor, Author):

Just a heads up, I noticed this check failing yesterday. It was passing when I ran it on Colab; looking into it.

@kunalghosh (Contributor, Author):

Fixed; it should have been if not np.allclose(...)
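
For clarity, a small self-contained sketch of the difference; the exact arguments of the fixed allclose call are elided above, so the zero-comparison here is only illustrative:

import numpy as np

residual = np.array([[1e-3], [2e-3]])  # stand-in residual vector

# Buggy check: a vector is always allclose to itself, so this
# condition is unconditionally True and the guard does nothing.
assert np.allclose(residual, residual)

# Fixed shape (arguments illustrative): enter the CG loop only
# while the residual is not already ~zero.
if not np.allclose(residual, np.zeros_like(residual)):
    pass  # run the CG iterations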

Setting = namedtuple("setting", "on")


class settings:
@bwengals (Contributor):

Since there are only two parameters here, could you pass eval_cg_tolerance and cg_tolerance as args to linear_cg? I think it would simplify things a little bit, since you wouldn't need this class.

@kunalghosh (Contributor, Author):

This class is just a placeholder; ideally I would put these two values in some sort of default settings file or class used by PyMC. I moved the constants to global variables and kept the logic the same as in GPyTorch.

@bwengals (Contributor):

There's no need to follow GPyTorch's particular way of doing things: as far as I'm aware, PyMC doesn't use global variables or settings files. It's important to be consistent with this sort of thing across the codebase.
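
Concretely, the suggestion amounts to a signature along the lines of the sketch below; the default values are placeholders for illustration, not constants PyMC actually defines:

def linear_cg(A, rhs, n_tridiag=0, cg_tolerance=1e-6, eval_cg_tolerance=1e-2):
    # Tolerances arrive as ordinary keyword arguments, so no
    # module-level settings class or global constants are needed.
    ...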

Comment on lines 169 to 189
linear_cg_retvals = linear_cg_updates(
result,
alpha,
residual_inner_prod,
eps,
beta,
residual,
precond_residual,
curr_conjugate_vec,
)

(
result,
alpha,
residual_inner_prod,
eps,
beta,
residual,
precond_residual,
curr_conjugate_vec,
) = linear_cg_retvals
@bwengals (Contributor):

This is just a small nitpick, but you could combine these two statements so you don't need the intermediate linear_cg_retvals.
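
That is, unpack the call's return values directly, along these lines:

(
    result,
    alpha,
    residual_inner_prod,
    eps,
    beta,
    residual,
    precond_residual,
    curr_conjugate_vec,
) = linear_cg_updates(
    result,
    alpha,
    residual_inner_prod,
    eps,
    beta,
    residual,
    precond_residual,
    curr_conjugate_vec,
)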

@kunalghosh (Contributor, Author):

Fixed in the latest commit

@bwengals (Contributor) commented Aug 6, 2022

It would be nice if you included your example code for testing this somewhere in the actual PR! You could make a brief notebook that demonstrates how to use it.

I left some minor comments, but I don't think they should hold up the merge. Depending on the strategy, this implementation will be translated to aesara/numba/jax, so the goal of this code should be clarity. Feel free to make changes and merge it when you're ready.

@bwengals (Contributor) commented Aug 6, 2022

BTW, is this PR still in draft mode, or do you think you're pretty much done?

@kunalghosh (Contributor, Author)

> BTW, is this PR still in draft mode, or do you think you're pretty much done?

Nope, this is ready to check in now :)

@kunalghosh marked this pull request as ready for review on September 27, 2022, 17:17
@twiecki requested a review from bwengals on September 28, 2022, 08:49
@bwengals (Contributor) left a comment

Sorry for the delay; everything looks copacetic.

@twiecki merged commit 0480f95 into pymc-devs:main on Oct 1, 2022