add roofline estimation of float8 gemm + overhead #668

Merged (1 commit) on Aug 15, 2024
Conversation

vkuzo (Contributor) commented Aug 13, 2024

Summary:

This PR adds a script to help estimate when it is expected to be faster to convert a `torch.nn.Linear` to float8, versus leaving it in bfloat16, for compute-bound training use cases.

How we do this:

  1. for bf16 and fp8 gemm time, use either benchmarks or a roofline estimate based on a percentage of peak compute throughput
  2. for float8 overhead, enumerate the worst case of reads and writes performed by torch.compile, and convert to time by assuming we can achieve a percentage of peak memory bandwidth
  3. compare (a) bf16_gemm_time and (b) fp8_gemm_time + fp8_overhead_time, and calculate the expected speedup.

Note that currently the gemm benchmarks require running a separate script, as documented in the argument descriptions. We can make this nicer at a future time.
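
To make the roofline arithmetic above concrete, here is a minimal sketch, assuming placeholder peak numbers (roughly H100 SXM specs) and a deliberately simplified read/write count; the actual script enumerates the per-scaling-type reads and writes in much more detail:

```
BF16_PEAK_FLOPS = 989e12    # assumed bf16 tensor core peak, FLOP/s (illustrative)
FP8_PEAK_FLOPS = 1979e12    # assumed fp8 tensor core peak, FLOP/s (illustrative)
PEAK_MEM_BW = 3.35e12       # assumed HBM bandwidth, bytes/s (illustrative)
ACHIEVABLE_FRACTION = 0.7   # assumed achievable fraction of both peaks

def gemm_time_s(M, K, N, peak_flops):
    # a single forward gemm for a linear layer does 2 * M * K * N FLOPs
    return 2 * M * K * N / (peak_flops * ACHIEVABLE_FRACTION)

def fp8_overhead_time_s(M, K, N):
    # simplified worst case: read input and weight in bf16 (2 bytes/elem)
    # and write them back as fp8 (1 byte/elem); the real script counts
    # more tensors (amax, scales, grad_output) per scaling type
    overhead_bytes = 3 * (M * K + K * N)
    return overhead_bytes / (PEAK_MEM_BW * ACHIEVABLE_FRACTION)

def expected_speedup(M, K, N):
    bf16_gemm_time = gemm_time_s(M, K, N, BF16_PEAK_FLOPS)
    fp8_gemm_time = gemm_time_s(M, K, N, FP8_PEAK_FLOPS)
    return bf16_gemm_time / (fp8_gemm_time + fp8_overhead_time_s(M, K, N))

for size in (1024, 2048, 4096, 8192):
    print(size, round(expected_speedup(size, size, size), 2))
```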

Test Plan:

```
python benchmarks/float8/float8_roofline.py \
    --outfile ~/local/tmp/20240813_del_del_dyn_bench_roofline.csv \
    --gemm_benchmarks_file ~/local/tmp/20240813_gemm_gpu_time_sweep_9_16.csv \
    --gemm_time_strategy benchmarks --model_torch_compile_limitations True \
    --scaling_type_input delayed --scaling_type_weight delayed \
    --scaling_type_grad_output dynamic | with-proxy gh gist create
- Creating gist...
✓ Created secret gist
https://gist.github.com/vkuzo/e3c4b274493140a7423c85d27863663a
```

Meta-only spreadsheet with more analysis (we should eventually OSS a version of this): https://docs.google.com/spreadsheets/d/1BpgGQjJwSmGen2QHukMmoCl7Ra5bgxAEQLEI0fma-Fs/edit?gid=195198894#gid=195198894

My tl;dr from the analysis above:

  1. the largest predictor of speedups of float8 gemms is M, K, N (not surprising).
  2. handwavy estimates of where we should focus torch.compile + float8 improvements for per-tensor scaling:
    a. getting torch.compile to be optimal for dynamic scaling: ~1.05x per-linear speedup vs dynamic worst-case
    b. getting torch.compile to be optimal for delayed scaling: ~1.13x per-linear speedup vs delayed worst-case
    c. after (a) and (b) are done, delayed scaling is expected to be 1.06x faster, per-linear, for large shapes
  3. with current torch.compile behavior and for equal M, K, N, the breakeven point of float8 vs bf16 is somewhere between 2048 and 4096 (a back-of-the-envelope version of this calculation is sketched after this list). If we get to best case behavior, the breakeven point would move to slightly below 2048. Of course, this is shape dependent.
  4. this script should give a tighter estimate of which production model layers could benefit from float8.
  5. with the current behavior of torch.compile, we don't expect speedups from delayed scaling vs dynamic scaling. Fixes in compile are needed to see more of a benefit here.
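
As a rough sanity check on the breakeven claim in (3), using illustrative symbols rather than values from the script: for M = K = N, write P_bf16 and P_fp8 for achievable bf16/fp8 gemm throughput, B for achievable memory bandwidth, and c for the overhead bytes moved per element of the two N x N conversion tensors. Float8 wins once 2*N^3/P_bf16 > 2*N^3/P_fp8 + c*N^2/B, i.e. once N > c * P_bf16 * P_fp8 / (2 * B * (P_fp8 - P_bf16)). With P_fp8 roughly 2x P_bf16, this reduces to N > c * P_bf16 / B, so the breakeven shape scales linearly with both the GPU's compute-to-bandwidth ratio and the overhead bytes per element, which is why shrinking the torch.compile overhead (smaller c) moves the breakeven point down.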

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/668

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9967622 with merge base 88a263a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Aug 13, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
drisspg (Contributor) left a comment:

Noob question: does sympy really speed up this script?

vkuzo (Contributor, Author) commented Aug 15, 2024

> Noob question: does sympy really speed up this script?

Not sure about speedup. The reason I used sympy is to easily see the contribution of arbitrary M, K, N to memory overhead.
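
For illustration, a minimal sympy sketch of that idea; the byte counts below are made up for the example and do not reproduce the script's actual accounting:

```
import sympy

M, K, N = sympy.symbols("M K N", positive=True)

BYTES_BF16, BYTES_FP8 = 2, 1

# illustrative worst case for dynamic scaling: read each bf16 tensor once to
# compute amax, read it again to cast, and write the fp8 result
input_overhead_bytes = 2 * BYTES_BF16 * M * K + BYTES_FP8 * M * K
weight_overhead_bytes = 2 * BYTES_BF16 * K * N + BYTES_FP8 * K * N

total_overhead_bytes = sympy.factor(input_overhead_bytes + weight_overhead_bytes)
print(total_overhead_bytes)  # 5*K*(M + N): the M, K, N contribution is explicit
```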

@vkuzo vkuzo merged commit 027bf39 into main Aug 15, 2024
14 checks passed