Re-write HQQ Fused Gemm Transpose #185

Merged
merged 5 commits into pytorch:main on May 3, 2024

Conversation

jeromeku (Collaborator)

@msaroufim

Re-write transposed fused gemm

  • The previous transpose implementation required the weights to be unpacked and repacked.
  • This was particularly problematic for fine-tuning, since it introduced additional overhead in either the forward or the backward pass, as noted by @mobicham in the CUDA-MODE Discord discussions.
  • Given that the intended use case for the fused kernel is fine-tuning, removing this overhead is essential.
  • The new version lets the user pack the weights once: the only change needed to run the transposed kernel is to pass transposed=True; no repacking of the weights or reshaping of the scales and zeros is required (see the sketch below).
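
A rough usage sketch of the pack-once workflow. The import names follow the prototype, but treat the exact signature, argument order, and tensor shapes here as illustrative assumptions; /prototype/hqq/README.md has the authoritative usage.

```python
import torch
# Assumed exports from the prototype; check torchao/prototype/hqq for the exact names.
from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

in_features, out_features, group_size = 4096, 4096, 64

# 4-bit quantized weight plus per-group scales/zeros as produced by HQQ (shapes illustrative).
W_q = torch.randint(0, 16, (in_features, out_features), dtype=torch.uint8, device="cuda")
scales = torch.rand(in_features * out_features // group_size, 1, dtype=torch.bfloat16, device="cuda")
zeros = torch.rand_like(scales)

W_packed = pack_2xint4(W_q)  # pack once, up front

x = torch.randn(8, in_features, dtype=torch.bfloat16, device="cuda")

# Forward: fused dequantize + gemm.
y = triton_mixed_mm(x, W_packed, scales, zeros, transposed=False, group_size=group_size)

# Backward needs the product with the transposed weight: same packed weights,
# same scales/zeros, only the flag changes -- no repacking or reshaping.
grad_y = torch.randn_like(y)
grad_x = triton_mixed_mm(grad_y, W_packed, scales, zeros, transposed=True, group_size=group_size)
```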

Checks

  • All tests pass for the new transpose implementation
  • Benchmarks show no significant performance difference between transposed and non-transposed kernels

Changes

  • Rewrite the transpose kernel implementation.
  • Add a docstring explaining the revised design.
  • Update tests and benchmarks to check the correctness and performance of the transpose kernel.
  • Update /prototype/hqq/README.md with a Usage section.

mobicham (Collaborator) commented on Apr 28, 2024

Great job @jeromeku! In a full training setup, the kernel is 1.11x - 1.26x faster than dequantize() -> bf16 torch.matmul (the vanilla HQQ reference). This is in line with the first benchmark I did at the layer level:

QLoRA fine-tuning (4-bit, group-size=None) | Llama2-7B | SDPA attention | GPU: A6000 Ada

batch_size=1, ctx_size=1024
HQQ ref (pad to ctx_size)                 : 2.06 it/sec | 18.9 GB 
HQQ ref (pad to ctx_size) + torch.compile : 2.33 it/sec | 18.6 GB -> 1.13x faster
Triton mm                                 : 2.60 it/sec | 18.8 GB -> 1.26x faster

batch_size=2, ctx_size=1024
HQQ ref (pad to ctx_size)                 : 1.17 it/sec | 32.2 GB 
HQQ ref (pad to ctx_size) + torch.compile : 1.32 it/sec | 31.0 GB -> 1.13x faster
Triton mm                                 : 1.33 it/sec | 31.8 GB -> 1.14x faster

batch_size=1, ctx_size=2048
HQQ ref (pad to ctx_size)                 : 1.07 it/sec | 33.3 GB
HQQ ref (pad to ctx_size) + torch.compile : 1.22 it/sec | 32.0 GB -> 1.14x faster
Triton mm                                 : 1.19 it/sec | 32.9 GB -> 1.11x faster 
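
For context, the "HQQ ref" path above is essentially "dequantize the weight to bf16, then a dense matmul". A rough stand-in sketch, using a generic affine dequantization with per-channel scales/zeros for simplicity rather than HQQ's actual dequantize():

```python
import torch

def dequant_ref(W_q, scales, zeros):
    # Generic affine dequantization: W ~ (W_q - zero) * scale.
    # HQQ's own dequantize() handles per-group reshaping; omitted here for brevity.
    return (W_q.to(torch.bfloat16) - zeros) * scales

def hqq_ref_linear(x, W_q, scales, zeros):
    # Baseline: materialize the full bf16 weight, then a dense matmul.
    W = dequant_ref(W_q, scales, zeros)
    return x @ W.T

# Toy shapes with per-channel scales/zeros for simplicity.
W_q = torch.randint(0, 16, (4096, 4096), dtype=torch.uint8, device="cuda")
scales = torch.rand(4096, 1, dtype=torch.bfloat16, device="cuda")
zeros = torch.rand(4096, 1, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 1024, 4096, dtype=torch.bfloat16, device="cuda")

y = hqq_ref_linear(x, W_q, scales, zeros)
# The fused Triton kernel folds the dequantization into the gemm instead of
# materializing W, which is where the end-to-end speedup above comes from.
```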
  • This is a setup in which all the prompts are longer than or equal to ctx_size, so the results depend on the dataset. Normally, with better sampling, you'd group prompts of similar size within the same batch and only pad to the longest sequence, which could be 1.5x faster with HQQ ref for some samples. It is possible to mimic something like this with Triton mm by warming up for a couple of pre-defined sizes ([batch_size, 256], [batch_size, 512], etc.) and padding to the closest one (see the sketch after this list).

  • Is it possible to make it compatible with torch.compile()? Currently it crashes when I try to compile the model. HQQ ref with torch.compile(model) is faster and uses a bit less memory. I would expect the same behavior with Triton mm.
    [screenshot: torch_compile_tritonmm]

  • fp8_fast_accum=True is 1.04 - 1.08x faster than fp8_fast_accum=False and uses a bit less memory, but we will need to check the model quality after training vs. bf16 accumulation. I wouldn't expect a huge drop in accuracy.

  • [x]  Benchmarks show no significant performance difference between transposed and non-transposed kernels
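
A minimal sketch of the bucketed-padding idea mentioned above; the bucket lengths and pad id are arbitrary choices for illustration:

```python
import torch
import torch.nn.functional as F

BUCKETS = (256, 512, 1024, 2048)  # pre-defined lengths to warm the kernel up for

def pad_to_bucket(input_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # Pad [batch, seq_len] token ids up to the next bucket instead of all the way to ctx_size.
    seq_len = input_ids.shape[1]
    target = next(b for b in BUCKETS if b >= seq_len)  # assumes seq_len <= max bucket
    return F.pad(input_ids, (0, target - seq_len), value=pad_id)

batch = torch.randint(0, 32000, (2, 300))
print(pad_to_bucket(batch).shape)  # torch.Size([2, 512]) -- padded to 512, not to ctx_size
```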
    

The older benchmark numbers are not consistent with the previous benchmark I did; I think it has to do with the shapes:
https://gist.github.com/mobicham/48e3ad537da6748e39d8a4ce27bfd612#file-triton_mm_benchmark-py-L154-L191

@msaroufim msaroufim merged commit be943a2 into pytorch:main May 3, 2024
15 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* add updated kernel and tests

* update transpose benchmark

* add usage section to README

---------

Co-authored-by: Mark Saroufim <[email protected]>