-
-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing linear conjugate gradients #62
Conversation
@bwengals Ping, can you take a look at the PR please :) |
will do. It might take me a couple days, it's a lot to go through, but it looks like you got it working, congrats! |
pymc_experimental/utils/linear_cg.py
Outdated
|
||
result = np.copy(initial_guess) | ||
|
||
if np.allclose(residual, residual): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a heads up, I noticed this check fail yesterday. Was passing when I ran it on Colab, looking into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed it should have been if not np.allclose(...)
pymc_experimental/utils/linear_cg.py
Outdated
Setting = namedtuple("setting", "on") | ||
|
||
|
||
class settings: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since there are only two parameters here, could you pass eval_cg_tolerance
and cg_tolerance
as args to linear_cg
? I think it would simplify things a little bit, since you wouldn't need this class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class is just a place holder, I would ideally put these two values in some sort of default settings file or class used by PyMC. I moved the constants to be global variables and kept the logic same as in GPyTorch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no need to follow GPyTorch's particular way of doing things -- in PyMC there aren't any global variables or settings files, that I'm aware of. It's important to be consistent with these sort of things across the codebase.
pymc_experimental/utils/linear_cg.py
Outdated
linear_cg_retvals = linear_cg_updates( | ||
result, | ||
alpha, | ||
residual_inner_prod, | ||
eps, | ||
beta, | ||
residual, | ||
precond_residual, | ||
curr_conjugate_vec, | ||
) | ||
|
||
( | ||
result, | ||
alpha, | ||
residual_inner_prod, | ||
eps, | ||
beta, | ||
residual, | ||
precond_residual, | ||
curr_conjugate_vec, | ||
) = linear_cg_retvals |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a small nitpick, but you could combine this, so you don't need to create the linear_cg_retvals
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in the latest commit
It would be nice if you included your example code to test this out in the actual PR somewhere! You could make a brief notebook that demonstrates how to use it. I left some minor comments, but I think shouldn't hold up the merge process. Depending on the strategy, this implementation will be translated to aesara/numba/jax, so the goal of this code should be clarity. Feel free to make changes and merge it when you're ready. |
BTW is this still PR still in draft mode, or do you think you're pretty much done? |
…n identity matrix is returned which doesn't affect downstream computation
Nope, this is ready to check-in now :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay, everything looks copacetic
This is a draft implementation of the linear conjugate gradients algorithm.
It is designed to be output the same as
gpytorch.util.linear_cg
and can be tested to give the same results using the code snippet below.Let's first generate the data:
Now we get the output from GPyTorch
And also from our implementation:
Check the outputs