
Switching to Torch 2.0 by default. #1922

Merged 16 commits on May 17, 2023
Conversation

jkulhanek
Contributor

This PR drops support for torch 1.12 and adds support for torch 2.0. Can we open a discussion on whether this is a good idea?

Note: we can try torch 2.0 compile to see if there are any speed improvements to be gained.

@SauravMaheshkar added the enhancement, speedup, dependencies, and python labels on May 15, 2023
LRScheduler,
)
except ImportError:
# Backward compatibility for PyTorch 1.x
jkulhanek (Contributor Author)

This offers backward compatibility for PyTorch 1.x. Should we keep it here or drop it?

Contributor

Let's keep it for now, but print a warning telling the user that they should update to PyTorch 2.0.
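A minimal sketch of what this fallback could look like, assuming the `try`/`except` wraps the `lr_scheduler` import; the warning text here is illustrative, not nerfstudio's actual wording:

```python
import warnings

try:
    # PyTorch >= 2.0 exposes LRScheduler as a public name.
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    # Backward compatibility for PyTorch 1.x, where only the
    # underscore-prefixed class exists.
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

    warnings.warn(
        "Support for PyTorch 1.x is deprecated; please upgrade to PyTorch 2.0.",
        DeprecationWarning,
    )
```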

@SauravMaheshkar SauravMaheshkar requested a review from tancik May 16, 2023 08:48
@igozali
Copy link

igozali commented May 16, 2023

Curious if this PR would include support for CUDA 12 as well?

@jkulhanek
Contributor Author
Contributor Author

Sure. For now only if you compile PyTorch from source: https://discuss.pytorch.org/t/pytorch-for-cuda-12/169447

@tancik
Contributor
Contributor

tancik commented May 16, 2023

> Sure. For now only if you compile PyTorch from source: https://discuss.pytorch.org/t/pytorch-for-cuda-12/169447

Not sure if tiny-cuda-nn and nerfacc would work, though.

@tancik
Contributor
Contributor

tancik commented May 16, 2023

The sphinx error is /home/docs/checkouts/readthedocs.org/user_builds/plenoptix-nerfstudio/checkouts/1922/docs/quickstart/installation.md:199: WARNING: 'myst' reference target not found: tiny-cuda-syntax-error

@jkulhanek force-pushed the jkulhanek/switch-to-newer-torch branch from a1f88d4 to 40775c5 on May 16, 2023 18:38
@tancik
Contributor
Contributor

tancik commented May 16, 2023

There appears to be an issue with instant-ngp: it runs much slower, but the quality is equivalent. Maybe @liruilong940607 has an idea?

x-axis seconds: [plot]

x-axis step: [plot]

@liruilong940607
Contributor
Contributor

Interesting, is this with torch2.0 + torch.compile() or simply switching to torch2.0?

@tancik
Contributor
Contributor

tancik commented May 16, 2023

> Interesting, is this with torch2.0 + torch.compile() or simply switching to torch2.0?

Just switching to torch 2.0

@jkulhanek
Contributor Author
Contributor Author

I am installing a fresh environment to test it.

@jkulhanek
Contributor Author
Contributor Author

torch.compile breaks inside the @time_function wrapper.
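For reference, one common way to make a timing decorator coexist with torch.compile is to keep it transparent via functools.wraps (so Dynamo sees the wrapped function's metadata) and to synchronize CUDA around the call so GPU timings are meaningful. This is a hypothetical sketch, not nerfstudio's actual `time_function`:

```python
import functools
import time

import torch


def time_function(fn):
    """Hypothetical timing wrapper; functools.wraps preserves the wrapped
    function's metadata so tracing tools can see through the decorator."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # flush pending GPU work before timing
        start = time.perf_counter()
        out = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for the GPU work we just launched
        wrapper.last_elapsed = time.perf_counter() - start
        return out

    wrapper.last_elapsed = None
    return wrapper


@time_function
def add(a, b):
    return a + b
```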

@liruilong940607
Contributor
Contributor

liruilong940607 commented May 17, 2023

The poster scene I was testing on also has the distortion parameters. Which scene were you running on? I can also test that out.

@liruilong940607
Contributor
Contributor

It is possible that the JIT of that function makes the timing weird. What if you just disable the JIT (and don't use torch.compile either)? I'm not familiar with either, but at least we would know whether it has something to do with them.

@tancik
Contributor
Contributor

tancik commented May 17, 2023

> The poster scene I was testing on also has the distortion parameters. Which scene were you running on? I can also test that out.

https://data.nerf.studio/nerfstudio/floating-tree.zip

@liruilong940607
Contributor
Contributor

> > The poster scene I was testing on also has the distortion parameters. Which scene were you running on? I can also test that out.
>
> https://data.nerf.studio/nerfstudio/floating-tree.zip

Ok, I can reproduce this slowness with floating-tree. The GPU stats are very different between the two runs.

[screenshots: GPU stats for the two runs]

@tancik
Contributor
Contributor

tancik commented May 17, 2023

I did some rough timings for 2K iters

| model | opt method | 1.13 time (sec) | 2.0 time (sec) |
| --- | --- | --- | --- |
| nerfacto | jit script | 66 | 73 |
| nerfacto | compile | - | 74 |
| nerfacto | none | 64 | 71 |
| ingp | jit script | 66 | 🐢 350 iters after 60 sec |
| ingp | compile | - | 🐢 10 iters after 60 sec |
| ingp | none | 58 | 60 |

Interestingly, this suggests that we shouldn't use any JIT, which is odd since adding JIT was a big speed boost at the time. Maybe it is because my training is CPU-limited? Can someone else test how the speed compares when not using JIT?

@jkulhanek
Contributor Author
Contributor Author

jkulhanek commented May 17, 2023

I have tested 2K iters on A100, CUDA 11.8, PyTorch 2.0.1:

| configuration | train time |
| --- | --- |
| torch.jit.script | 300s |
| eager mode | 105s |
| torch.compile | 490s |
| torch.compile(dynamic=True) | 101s |
| torch.compile(dynamic=True, mode="reduce-overhead") | 93s |

The initial compilation takes very long, though.
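For readers following along, the configurations in the table correspond to calls like the sketch below. The tiny module is a stand-in for a nerfstudio model component, and `backend="eager"` is used here only so the sketch runs without a GPU or C++ toolchain; the timings above used the default inductor backend:

```python
import torch

# Stand-in module; the real benchmarks compiled nerfstudio model components.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3),
)

# Variants benchmarked above (default inductor backend, on CUDA):
#   torch.compile(model)
#   torch.compile(model, dynamic=True)
#   torch.compile(model, dynamic=True, mode="reduce-overhead")
compiled = torch.compile(model, dynamic=True, backend="eager")

x = torch.randn(1024, 16)
y = compiled(x)  # the first call triggers compilation; later calls reuse it
```

With `dynamic=True`, Dynamo compiles a shape-generic graph instead of recompiling for every new ray-batch size, which is why it helps here where batch sizes vary between iterations.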

@jkulhanek
Contributor Author
Contributor Author

jkulhanek commented May 17, 2023

Actually for me torch.compile is the fastest. But the compilation takes a long time at the beginning:
[screenshot: training curve with compilation overhead]

...but first iterations:
[screenshot: first-iteration timings]

@tancik
Contributor
Contributor

tancik commented May 17, 2023

So is eager the way to go for now? @jkulhanek were these tests with the ingp and the floating tree dataset?

@jkulhanek
Contributor Author
Contributor Author

@tancik, can you please try @torch.compile(dynamic=True)? This is the fastest for me now. I will still play with the configuration a bit, but I believe this could be it.

@jkulhanek
Contributor Author
Contributor Author

> So is eager the way to go for now? @jkulhanek were these tests with the ingp and the floating tree dataset?

Yes, ingp with the floating-tree dataset. I wonder whether the A100 is really that much slower than the 4090, or what the problem is.

@jkulhanek
Contributor Author
Contributor Author

jkulhanek commented May 17, 2023

Alternatively, there is also mode="reduce-overhead". I would need to do more precise profiling; there is a tradeoff between spending more time on compilation at the beginning of training and having slower iterations.

I suggest we stick with torch.compile(dynamic=True, mode="reduce-overhead") for now (see the table).

@liruilong940607
Contributor
Contributor

I can confirm that torch.compile(dynamic=True, mode="reduce-overhead") is the best option:

[screenshot: comparison plot]

Additionally, I tested replacing radial_and_tangential_undistort() with nerfacc.cameras.opencv_lens_undistortion(), which fuses all the computation into a single CUDA kernel. It seems torch.compile(dynamic=True, mode="reduce-overhead") gives similar speed to nerfacc's explicit fusion, with only a bit of compilation overhead at the start (nerfacc needs compiling too, but that doesn't show here):

[screenshot: comparison plot]

Tested with INGP, TITAN RTX, CUDA 11.7, 5000 iters, floating-tree scene.
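For context on what is being fused: radial undistortion is typically solved by a small fixed-point iteration, so eager mode launches many tiny kernels per step, which is exactly what torch.compile or nerfacc's fused kernel collapses. A minimal sketch with only two radial coefficients (the real nerfstudio function also handles tangential terms):

```python
import torch


def radial_undistort(xd, yd, k1, k2, iters=10):
    """Fixed-point iteration: find (x, y) whose radial distortion maps
    back to the observed distorted coordinates (xd, yd)."""
    x, y = xd.clone(), yd.clone()
    for _ in range(iters):
        r2 = x * x + y * y
        d = 1.0 + r2 * (k1 + r2 * k2)  # radial distortion factor
        # Each of these small elementwise ops is a separate CUDA kernel in
        # eager mode; compilation can fuse the whole loop body.
        x = xd / d
        y = yd / d
    return x, y
```

With zero distortion coefficients the iteration is the identity, which makes the function easy to sanity-check.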

@jkulhanek
Contributor Author
Contributor Author

Cool! @liruilong940607, can you please also try with an additional backend="eager"? That setup is the fastest for me.

@liruilong940607
Contributor
Contributor

liruilong940607 commented May 17, 2023

torch.compile(dynamic=True, mode="reduce-overhead", backend="eager") is not that good for me (the compile-eager line in the plot).

[screenshots: training curves including the compile-eager line]

@tancik
Contributor
Contributor

tancik commented May 17, 2023

Here are some results on nerfacto:
[plot: nerfacto training curves]

- @torch.compile(dynamic=True, mode="reduce-overhead", backend="eager")
- @torch.compile(dynamic=True, mode="reduce-overhead")
- @torch.compile(dynamic=True, backend="eager")
- @torch.compile(dynamic=False, mode="reduce-overhead", backend="eager")
- main branch, pytorch 1.13

They are all basically the same, with the exception of @torch.compile(dynamic=True, mode="reduce-overhead"), which is worse.

@liruilong940607
Contributor
Contributor

Oh, I didn't know torch 1.13 also supports torch.compile. Isn't it a feature introduced in torch 2.0?

@tancik
Contributor
Contributor

tancik commented May 17, 2023

> Oh, I didn't know torch 1.13 also supports torch.compile. Isn't it a feature introduced in torch 2.0?

It doesn't; I'm just plotting the main branch for reference (it uses @torch.jit.script).

@tancik
Contributor
Contributor

tancik commented May 17, 2023

Here are the same plots for ingp

[plot: ingp training curves]

In this case, @torch.compile(dynamic=True, mode="reduce-overhead") starts off slow but ends up being the fastest.

@liruilong940607
Contributor
Contributor

I see. Does this mean torch.compile behaves basically the same as jit.script for static-shape input, and they only behave differently for dynamic shapes (which is when all these arguments start to matter)?

I'm curious what the logic is behind the different choices. Some toy examples might be helpful. (Or is there a tutorial somewhere that explains these things?)

It seems like we are ending up in a situation where different models (or even data?) need different optimization strategies, which kind of makes sense to me.

@jkulhanek
Contributor Author
Contributor Author

It would be cool if we could then torch.compile the whole model to see if we can speed it up. Currently, it breaks in several places for nerfacto and other models. I expected the compilation to take much longer, but if it is too slow we can cache the compiled model to disk (the API is currently quite experimental: https://pytorch.org/get-started/pytorch-2.0/#user-experience), so this would only happen once, same as nerfacc kernels, I guess.

@tancik (Contributor) left a comment:

LGTM

@tancik tancik merged commit e898f56 into main May 17, 2023
@tancik tancik deleted the jkulhanek/switch-to-newer-torch branch May 17, 2023 23:03