Switching to Torch 2.0 by default. #1922
Conversation
        LRScheduler,
    )
except ImportError:
    # Backward compatibility for PyTorch 1.x
This offers backward compatibility for PyTorch 1.x. Should we keep it here or drop it?
Let's keep it for now, but print a warning to the user that they should update to PyTorch 2.0.
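A minimal sketch of what that fallback plus warning could look like (the import names match PyTorch, but the warning text and placement are assumptions, not the PR's actual code):

```python
import warnings

try:
    # Public name in PyTorch >= 2.0
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    # PyTorch 1.x only exposes the private base class
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

    warnings.warn(
        "Support for PyTorch 1.x is deprecated; please upgrade to PyTorch 2.0.",
        DeprecationWarning,
    )
```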
Curious if this PR would include support for CUDA 12 as well?
Sure, for now only if you compile PyTorch from source: https://discuss.pytorch.org/t/pytorch-for-cuda-12/169447
Not sure if Tinycudann and nerfacc would work, though.
Test renders using nerfacto: no noticeable quality differences.
The Sphinx error is
There appears to be an issue with
Interesting, is this with torch 2.0 +
Just switching to torch 2.0.
I am installing a fresh environment to test it.
torch.compile breaks in the @time_function wrapper.
The poster scene I was testing on also has the distortion parameters. Which scene were you running on? I can also test that out.
It is possible that the JIT of that function makes the timing weird. What if you just disable the JIT (and also do not use torch.compile)? I'm not familiar with either, but at least we would know whether it has something to do with them.
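One low-effort way to run that experiment is to gate compilation behind an environment variable, so the plain eager path can be timed for comparison. This is only a sketch; the `maybe_compile` helper and `USE_TORCH_COMPILE` variable are made-up names, not part of nerfstudio:

```python
import os

import torch


def maybe_compile(fn):
    """Apply torch.compile only when explicitly enabled (hypothetical helper)."""
    if os.environ.get("USE_TORCH_COMPILE", "0") == "1":
        return torch.compile(fn)
    return fn  # eager fallback, nothing to interfere with timing decorators


@maybe_compile
def posenc(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a hot function that was previously torch.jit.script'ed.
    return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
```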
I did some rough timings for 2K iters
Interestingly, this suggests that we shouldn't use any JIT, which is odd, since when we added JIT it was a big speed boost. Maybe it is because my training is CPU-limited? Can someone else test how the speed compares when not using JIT?
I have tested 2K iters on A100, CUDA 11.8, PyTorch 2.0.1:
The initial compilation takes sooo long.
So is eager the way to go for now? @jkulhanek were these tests with the
@tancik, can you please try @torch.compile(dynamic=True)? This is the fastest for me now. I will still play with the configuration for a bit, but I believe this could be it.
Yes, ingp with the floating-tree dataset. I wonder whether the A100 is really so much slower than the 4090, or what the problem is.
Alternatively there is also  I suggest we stick with
I can confirm that (plot omitted). Additionally I tested using (plot omitted). Test with INGP, TITAN RTX, CUDA 11.7, 5000 iters, floating-tree scene.
Cool! @liruilong940607, can you please also try with the additional backend="eager" option? That setup is the fastest for me.
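For concreteness, a rough sketch of how the variants discussed here could be timed side by side; the encode function, batch size, and iteration count are placeholders, not the actual benchmark:

```python
import time

import torch


def encode(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for a hot function such as a positional encoding.
    return torch.cat([torch.sin(x), torch.cos(x)], dim=-1)


variants = {
    "eager": encode,
    "compile (default)": torch.compile(encode),
    "compile (dynamic=True)": torch.compile(encode, dynamic=True),
    "compile (dynamic=True, backend='eager')": torch.compile(encode, dynamic=True, backend="eager"),
}

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(65536, 3, device=device)
for name, fn in variants.items():
    fn(x)  # warm-up, triggers compilation where applicable
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"{name}: {time.perf_counter() - start:.3f}s")
```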
Oh, I didn't know torch 1.13 also supports
It doesn't. I'm just plotting the main branch for reference (it uses @torch.jit.script).
I see. Does this mean torch.compile behaves basically the same as jit.script for static-shape input, and that they only behave differently for dynamic shapes (which is when all these arguments start to matter)? I'm curious what the logic is behind the different choices; some toy examples might be helpful. (Or is there a tutorial somewhere that explains these things?) It seems like we are ending up in a situation where different models (or even data?) need different optimization strategies, which kind of makes sense to me.
Here is some nice info:
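A toy example along the lines asked about above (purely illustrative; the behavior notes reflect PyTorch 2.0 defaults, not anything specific to this PR):

```python
import torch


def fused_op(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * torch.sigmoid(x)


# TorchScript compiles the Python source into TorchScript IR up front.
scripted = torch.jit.script(fused_op)

# torch.compile captures graphs lazily at call time via TorchDynamo. By default
# it specializes on the shapes it sees and recompiles when they change;
# dynamic=True instead tries to build a single shape-polymorphic graph.
compiled = torch.compile(fused_op, dynamic=True)

for n in (1024, 2048, 4096):  # e.g. a varying number of ray samples per batch
    x = torch.randn(n, 3)
    assert torch.allclose(scripted(x), compiled(x))
```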
It would be cool if we could then torch.compile the whole model to see if we can speed it up. Currently, there are several places where it breaks for nerfacto and other models. I expected the compilation to take much longer, but if it is too much we can cache the compiled model to disk (the API is currently quite experimental: https://pytorch.org/get-started/pytorch-2.0/#user-experience), so this would only happen once, same as the nerfacc kernels, I guess.
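A sketch of what compiling a whole model would look like; the module below is a stand-in, not nerfacto:

```python
import torch
from torch import nn


class TinyField(nn.Module):
    """Stand-in for a NeRF-style MLP, not the actual nerfacto model."""

    def __init__(self) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(63, 256), nn.ReLU(), nn.Linear(256, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


model = TinyField()
# torch.compile wraps the module; unsupported ops (e.g. custom CUDA kernels from
# tiny-cuda-nn or nerfacc) cause graph breaks and fall back to eager execution.
compiled_model = torch.compile(model)
out = compiled_model(torch.randn(8192, 63))
```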
LGTM
This PR drops support for torch 1.12 and adds support for torch 2.0. Can we open a discussion about whether this is a good thing to do or not?
Note: we can try torch 2.0 compile to see if there are any speed improvements to be gained.