Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a prototype of MX format training and inference #264

Merged
merged 1 commit into from
May 28, 2024

Conversation

vkuzo
Copy link
Contributor

@vkuzo vkuzo commented May 23, 2024

Summary:

The MX numerical formats are new low precision formats with recent acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR adds a reference native PyTorch implementation of training and inference primitives for using MX accelerated matrix multiplications. Currently, we use a reference layout (scale and raw data stored separately) and an emulated matrix multiplication.

Test Plan:

// lint
lintrunner --configs .lintrunner.toml -a
// tests
pytest -s test/prototype/mx_formats/*
// benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py

Reviewers:

Subscribers:

Tasks:

Tags:

Copy link

pytorch-bot bot commented May 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/264

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4425d0d with merge base 5b04ff0 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2024
@vkuzo
Copy link
Contributor Author

vkuzo commented May 23, 2024

need to add license and fix CI

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first round of feedback on docs and ci related stuff - will do another pass

test/prototype/mx_formats/test_mx_linear.py Show resolved Hide resolved
test/prototype/mx_formats/test_custom_cast.py Show resolved Hide resolved
torchao/prototype/mx_formats/fp_formats.py Outdated Show resolved Hide resolved
torchao/prototype/mx_formats/README.md Show resolved Hide resolved
```python
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
x = torch.randn(...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put a functioning snippet that people can copy paste

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = Model(...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment on a functional snippet

torchao/prototype/mx_formats/README.md Show resolved Hide resolved
torchao/prototype/mx_formats/README.md Show resolved Hide resolved
torchao/prototype/mx_formats/benchmarks/bench_qdq.py Outdated Show resolved Hide resolved
@vkuzo vkuzo force-pushed the 20240523_mx_formats_code_move branch 7 times, most recently from aef63f9 to 0c84b1a Compare May 24, 2024 15:56
@vkuzo
Copy link
Contributor Author

vkuzo commented May 24, 2024

ok, CI is green, going to address the other comments now

@msaroufim msaroufim self-requested a review May 24, 2024 16:37
msaroufim
msaroufim previously approved these changes May 24, 2024
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stamping for now, will need a day or more to read the mx spec and don't wanna block your PR until then

@vkuzo
Copy link
Contributor Author

vkuzo commented May 24, 2024

Stamping for now, will need a day or more to read the mx spec and don't wanna block your PR until then

I can wait for review, would rather only land once people are ok with the code.

@vkuzo vkuzo requested a review from msaroufim May 24, 2024 16:40
@vkuzo vkuzo force-pushed the 20240523_mx_formats_code_move branch 2 times, most recently from 455f148 to 77541c5 Compare May 24, 2024 19:21
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Really enjoyed reviewing this. some minor nits but we should be good to merge

README.md Outdated
@@ -99,6 +99,7 @@ To learn more try out our APIs, you can check out API examples in
3. Support for lower precision [dtypes](./torchao/dtypes) such as
- [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code
- [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py)
- [MX](https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats) implementing the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), prototype as the hardware support is not available yet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth expanding to mention MX including fp8/6/4 and int8 - MX is still new terminology

torchao/prototype/mx_formats/constants.py Show resolved Hide resolved
import torch

# This is conceptually an enum of non-core dtypes
# if someone has time to verify torch.compile compatibility, it could be made
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the comment intending to say that torch.compile is breaking on enum or that in the future torch.compile support can be checked AND indepedently this could be made into an enum.

Indeed I feel like an enum would make this significantly easier to read cause you could conceptually print every row in the spec

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weird thing is that some of these dtypes are in core (the float8 ones) and some aren't (float6/float4/mx's spec of int8, etc). I think it would be nice to have a clean structure unifying all of that, I just haven't had the time. Definitely open for someone (or future me) to improve this.

def compute_error(x, y):
Ps = torch.norm(x) # noqa: TOR101
Pn = torch.norm(x - y) # noqa: TOR101
return 20 * torch.log10(Ps / Pn)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's already a util for this exactl function in the code, somewhere in gptq IIRC so can we put this in torchao/utils.py instead?


### MXTensor

This is casts between fp32/bf16 and MX formats implemented in native PyTorch.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw the spec didn't seem too prescriptive around what the source dtype should be

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp32 and bf16 is what we have today, we can make it clearer that other dtypes can be added in the future


def get_fp_scale(scale_e8m0):
s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
# TODO(later): it would be nice if there was a way to do the 2^x operation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense! I will punt this to a future person, this shouldn't be that important for e2e performance.

return g, None, None


@torch._dynamo.allow_in_graph
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a public API torch.compiler.allow_in_graph - also curious why this was needed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is necessary for compile to fully support training, and this line is copy-pasta from float8_experimental, ideally while these two products are in different codebases I'm hoping for these kind of issues to get fixed in float8_experimental first and be copied here. Once we unify it will be easier.

return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E3M2)


if has_triton():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was the inductor codegen note adequate? Wondering if we can eventually remove this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current codegen was slow, tracked in pytorch/pytorch#124002 .

print("\n")


if __name__ == "__main__":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdyy about renaming this file to have spec in the name? I quite like and we can recommend people to cross reference the text spec with your code in the main README

torchao/prototype/mx_formats/mx_ops.py Show resolved Hide resolved
Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Really enjoyed reviewing this. some minor nits but we should be good to merge

@msaroufim msaroufim self-requested a review May 28, 2024 04:53
@vkuzo vkuzo force-pushed the 20240523_mx_formats_code_move branch from 77541c5 to ad704f0 Compare May 28, 2024 16:22
Summary:

The MX numerical formats are new low precision formats with recent
acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR adds a reference native PyTorch implementation of training and
inference primitives for using MX accelerated matrix multiplications.
Currently, we use a reference layout (scale and raw data stored
separately) and an emulated matrix multiplication.

Test Plan:

```
// tests
pytest -s test/prototype/mx_formats/*
// benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo vkuzo force-pushed the 20240523_mx_formats_code_move branch from ad704f0 to 4425d0d Compare May 28, 2024 16:29
@vkuzo
Copy link
Contributor Author

vkuzo commented May 28, 2024

@msaroufim needs a review again since I think this repo is setup to re-require reviews after changes, all of the feedback has been either addressed or explained why not addressed right now.

@vkuzo
Copy link
Contributor Author

vkuzo commented May 28, 2024

and thank you for the review!

@vkuzo vkuzo merged commit a7483f2 into main May 28, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
Summary:

The MX numerical formats are new low precision formats with recent
acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR adds a reference native PyTorch implementation of training and
inference primitives for using MX accelerated matrix multiplications.
Currently, we use a reference layout (scale and raw data stored
separately) and an emulated matrix multiplication.

Test Plan:

```
// tests
pytest -s test/prototype/mx_formats/*
// benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Summary:
- Removed Int8DynActInt4Weight code
- Use torchao to achieve the same

Test Plan:
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 128}}'
--checkpoint-path stories110M.pt
--params-path params.json
--output-pte-path /tmp/stories110m_a8w4dq.pte
Run
./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z
/tmp/tokenizer.bin  -n 200 -t 0

Reviewers:

Subscribers:

Tasks:

Tags:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants