
Clean up FP6-LLM #304

Merged: 23 commits merged into pytorch:main on Jun 9, 2024

Conversation

gau-nernst (Collaborator)

  • Remove the original FP6 quantization code (qtorch and C++ bit-packing)
  • Replace the FP32<->FP6 dtype conversion with @vkuzo's implementation for MX dtypes (a usage sketch follows this list)
    • I also migrated some of my FP32->FP6 rounding test cases to the MX custom cast tests.
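
For context, a minimal sketch of how the new cast path is invoked. The import path and function name are taken from the benchmark script later in this thread; the shape and the assumption that the result is an unpacked integer tensor (one FP6 E3M2 bit pattern per element) are illustrative only.

```python
import torch

from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked

fp32_weight = torch.randn(8192, 8192)                # FP32 source tensor
fp6_unpacked = f32_to_f6_e3m2_unpacked(fp32_weight)  # one FP6 E3M2 value per element, not bit-packed
print(fp6_unpacked.shape, fp6_unpacked.dtype)        # same shape; integer storage dtype is assumed here
```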

pytorch-bot bot commented Jun 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/304

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 6f8e7e9 with merge base 000a0fd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 3, 2024
vkuzo (Contributor) commented Jun 4, 2024

> Replace FP32<->FP6 dtype conversion with @vkuzo's implementation for MX dtypes

I have two questions:

  1. Just curious, have you benchmarked the performance of this change? I have not optimized the MX dtypes for performance yet.
  2. I think it's not ideal to depend on code in the prototype folder. If other people need custom_cast.py, it might be good to move it to a common place outside of the prototype folder.

gau-nernst (Collaborator, Author) commented Jun 4, 2024

@vkuzo

  1. I did benchmark it. IIRC, your implementation is faster than mine on both CPU and GPU (with torch.compile). I will benchmark it again and update the results in this comment. (I also did a correctness comparison: your implementation and mine are bit-identical for all FP16 bit patterns - I can't test all FP32 bit patterns because it would take too long.)
  2. Yes, I want to discuss this with you as well. It would be good to move it to a separate file for low-bit-width floating-point conversion. The FP6-LLM author added support for FP5 E2M2, so I want to support FP5 as well (MX doesn't have FP5).

A question for @msaroufim: is there a guideline for when we should or should not decorate a function with @torch.compile? These functions rely on torch.compile to be fast, but there are cold-start (the first call is slow) and dynamic-shape (can we avoid re-compiling on different shapes, in a guaranteed way?) problems. I haven't looked into these problems much yet, so maybe they are not that significant. I also saw @cpuhrsch's comment on another PR that @torch.compile won't work on Windows - perhaps we need a wrapper around the torch.compile decorator (a sketch of one possible wrapper is below)? But Windows is not officially supported for now, I suppose.
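
A minimal sketch of what such a wrapper could look like (my illustration only; torchao does not necessarily ship a helper with this name or behavior): apply torch.compile where it is expected to work, and fall back to the eager function otherwise.

```python
import platform
from functools import wraps

import torch


def maybe_compile(fn=None, **compile_kwargs):
    """Hypothetical decorator: torch.compile where supported, eager fallback elsewhere."""
    def decorator(f):
        # Triton codegen is unavailable on Windows, so skip compilation there.
        if platform.system() == "Windows":
            return f
        compiled = torch.compile(f, **compile_kwargs)

        @wraps(f)
        def wrapper(*args, **kwargs):
            return compiled(*args, **kwargs)

        return wrapper

    return decorator(fn) if fn is not None else decorator


@maybe_compile
def unpacked_cast_example(x: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real casts live in torchao.prototype.mx_formats.custom_cast.
    return x
```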

Update: FP32->FP6_E3M2 conversion of an (8192, 8192) matrix (main branch), benchmarked with torch.utils.benchmark.Timer (so the CPU runs use only 1 thread). CPU: Ryzen 5600; GPU: 4070 Ti SUPER.

| device | mode | op | time (ms) |
|--------|---------|-----------------------------|-----------|
| CPU | eager | _to_float6_e3m2_pt | 1702.95 |
| CPU | eager | f32_to_f6_e3m2_unpacked | 1604.09 |
| CPU | compile | _to_float6_e3m2_pt | 445.011 |
| CPU | compile | f32_to_f6_e3m2_unpacked | 214.897 |
| CPU | C++ | to_float6_e3m2_unpacked_cpu | 360.433 |
| CUDA | eager | _to_float6_e3m2_pt | 13.4336 |
| CUDA | eager | f32_to_f6_e3m2_unpacked | 14.9207 |
| CUDA | compile | _to_float6_e3m2_pt | 0.578769 |
| CUDA | compile | f32_to_f6_e3m2_unpacked | 0.577399 |

CUDA is memory-bound so the implementation does not matter much (as long as it is correct). For CPU, your implementation is faster, especially with torch.compile (and faster than my C++ implementation). Though I found that CPU benchmark results tend to vary greatly across CPUs...
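
As a rough sanity check on the memory-bound claim (my back-of-the-envelope numbers, assuming the unpacked output is stored one byte per element and a peak memory bandwidth of roughly 672 GB/s for the 4070 Ti SUPER): the kernel reads 8192 × 8192 × 4 B ≈ 268 MB of FP32 and writes ≈ 67 MB of output, about 336 MB of traffic in total, which takes ≈ 0.5 ms at peak bandwidth. That lines up with the ≈ 0.58 ms measured for both compiled CUDA variants above. The benchmark script used is below.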

```python
from functools import partial

import torch
import pandas as pd
from torch.utils.benchmark import Timer
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked
from torchao.dtypes.float6_e3m2 import _to_float6_e3m2_pt


def benchmark(f, *args):
    # Timer handles warmup and CUDA synchronization; report the median time in ms.
    measurement = Timer(
        stmt="f(*args)",
        globals={"f": f, "args": args},
    ).blocked_autorange()
    return measurement.median * 1000


if __name__ == "__main__":
    M = 8192
    N = 8192
    fp32_weight = torch.randn(M, N)
    fp32_weight_cuda = fp32_weight.cuda()

    # The two FP32 -> FP6 E3M2 implementations being compared (unpacked output).
    functions = [
        ("_to_float6_e3m2_pt", partial(_to_float6_e3m2_pt, no_bit_packing=True)),
        ("f32_to_f6_e3m2_unpacked", f32_to_f6_e3m2_unpacked),
    ]

    results = []
    for name, f in functions:
        # Eager mode, CPU and CUDA.
        results.append(["CPU", "eager", name, benchmark(f, fp32_weight)])
        results.append(["CUDA", "eager", name, benchmark(f, fp32_weight_cuda)])

        # torch.compile'd versions of the same functions.
        results.append(["CPU", "compile", name, benchmark(torch.compile(f), fp32_weight)])
        results.append(["CUDA", "compile", name, benchmark(torch.compile(f), fp32_weight_cuda)])

    df = pd.DataFrame(results, columns=["device", "mode", "op", "time (ms)"])
    df = df.sort_values(["device", "mode"], ascending=[True, False])
    print(df.to_markdown(index=False))
```

msaroufim (Member) commented

So for Windows, the main issue is that torch.compile() code-generates Triton kernels, and Triton hasn't prioritized Windows support. For the Inductor CPU backend this should be less of an issue; I suspect there might be an overly aggressive assert somewhere, though.

I would say that overall everything should be compilable. The cold-start problem is indeed annoying and is actively being worked on; some broader plans have been shared here: https://dev-discuss.pytorch.org/t/how-to-bring-compile-time-down-to-zero-our-plans-and-direction-may-14th-edition/2089

Regarding dynamic shapes, the way I iterate through things is to first eliminate graph breaks and then recompilations. This has been my go-to guide: https://github.com/pytorch/pytorch/blob/main/docs/source/torch.compiler_troubleshooting.rst
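
On the "avoid re-compiles on different shapes" question above, a minimal sketch of the generic torch.compile knobs (nothing torchao-specific; whether the FP6 cast actually avoids recompiles this way would need to be verified):

```python
import torch
import torch._dynamo
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked

# Option 1: ask the compiler to trace with dynamic shapes from the start.
cast_dynamic = torch.compile(f32_to_f6_e3m2_unpacked, dynamic=True)

# Option 2: keep the default compile, but mark the dimensions expected to vary
# so the first trace already treats them symbolically.
cast_default = torch.compile(f32_to_f6_e3m2_unpacked)
x = torch.randn(4096, 8192)
torch._dynamo.mark_dynamic(x, 0)  # this dim may change between calls
torch._dynamo.mark_dynamic(x, 1)

y = cast_default(x)                        # first call triggers compilation
y = cast_default(torch.randn(1024, 2048))  # new shape; ideally served by the same graph
```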

Also, just FYI: we removed the requirement to have branches up to date before merging. There was a breaking change in PyTorch that was just reverted, so please rebase your changes to get rid of the CI flakes.

@gau-nernst gau-nernst marked this pull request as ready for review June 9, 2024 16:19
@gau-nernst gau-nernst requested a review from msaroufim June 9, 2024 16:19
@gau-nernst gau-nernst requested a review from vkuzo June 9, 2024 16:19
msaroufim (Member) left a comment

that's a lot of deletions 🗡️

@msaroufim msaroufim merged commit cd8f647 into pytorch:main Jun 9, 2024
13 checks passed
msaroufim added a commit that referenced this pull request Jun 9, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* override load from state dict

* fix prefix

* migrate to mx primitive

* remove unneeded code

* comment out test

* remove

* add rounding test for f6_e3m2

* update tests

* remove openmp flag

* update benchmark script

* test negative number

* remove qtorch dep

* fix type casting

* add view

* fix strange pytest behavior

* only skip tests requiring PyTorch 2.4

* remove weight loading magic
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* eval and GPTQ work

Summary: fleshing out the eval code so it works reliably, adding CI, and
adding GPTQ. Fixed the defaults for eval/GPTQ so they work meaningfully
without being specified. Note: we need a better way to save/load GPTQ
models since they take so long to quantize; I tried using .so files but
that doesn't seem to work reliably. Also added eval and GPTQ to CI.

Test Plan:

```bash
python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
  --device cuda --dtype bfloat16

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4" : {"groupsize" : 32} }' \
    --compile

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4" : {"groupsize" : 32} }'

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4-gptq" : {"groupsize" : 32} }'
```

...running...


* fix language in help doc


* declare scales_and_zeros

---------

Co-authored-by: HDCharles <[email protected]>