torch.compile cast to mxfp8 across dim0 and dim1 should be performant #1788

Open
vkuzo opened this issue Feb 26, 2025 · 1 comment

vkuzo (Contributor) commented Feb 26, 2025

What this cast is doing

  • reshape the tensor into shape (-1, block_size), where block_size is usually 32 or 16
  • for each block, calculate a single scale, then cast that block to torch.float8_e4m3fn
  • do the above across both dim0 and dim1 (a sketch of this cast is shown after this list)

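For reference, a minimal eager-mode sketch of the per-block cast along the last dim (an illustration only, not the actual torchao implementation; the exact scale calculation may differ):

```python
import torch

def to_mxfp8_along_last_dim(x: torch.Tensor, block_size: int = 32):
    # Illustrative only: one power-of-two scale per block of `block_size`
    # elements along the last dim, data stored in float8_e4m3fn.
    orig_shape = x.shape
    x = x.reshape(-1, block_size)
    amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32)
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # power-of-two scale so each block's max maps into e4m3's range
    scale = torch.pow(2.0, torch.floor(torch.log2(amax / e4m3_max)))
    scale = torch.clamp(scale, min=2.0 ** -127)  # guard against all-zero blocks
    data = (x.to(torch.float32) / scale).to(torch.float8_e4m3fn)
    return data.reshape(orig_shape), scale

# the dim0 cast is the same operation applied to the transpose:
# data_t, scale_t = to_mxfp8_along_last_dim(x.t().contiguous())
```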
What we currently see from inductor is two kernels, one for dim0 and one for dim1:

TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1

Output: https://gist.github.com/vkuzo/7a9f104872790e58b316c7ba477fcbf5

An MX-compliant 32x32 block of a bfloat16 tensor occupies 2 KiB of memory, so it should easily fit into the shared memory of an SM on a modern GPU. We should explore doing this cast across dim0 and dim1 in a tiled fashion, so that each tile is loaded into shared memory only once.
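To make the desired access pattern concrete, here is a plain-PyTorch illustration of the proposed tiling (reusing the hypothetical to_mxfp8_along_last_dim helper sketched above; the real fix would be a single fused Triton/inductor kernel, and scale handling is omitted for brevity):

```python
import torch

def cast_both_dims_tiled(x: torch.Tensor, tile: int = 32):
    # Illustration of the access pattern: each 32x32 tile is read once and
    # yields both the row-wise (dim1) and column-wise (dim0) mxfp8 data.
    # Assumes both dims are divisible by `tile` and `tile` == mx block size.
    M, K = x.shape
    out_dim1 = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=x.device)
    out_dim0 = torch.empty(K, M, dtype=torch.float8_e4m3fn, device=x.device)
    for i in range(0, M, tile):
        for j in range(0, K, tile):
            t = x[i:i + tile, j:j + tile]  # the single "load" of this tile
            out_dim1[i:i + tile, j:j + tile] = to_mxfp8_along_last_dim(t)[0]
            out_dim0[j:j + tile, i:i + tile] = to_mxfp8_along_last_dim(t.t().contiguous())[0]
    return out_dim1, out_dim0
```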

vkuzo (Contributor, Author) commented Feb 26, 2025

Tried:

torch._inductor.config.triton.prefer_nd_tiling = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True

It didn't seem to help.
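For context, this is roughly how those settings would be exercised (a hypothetical repro sketch using the to_mxfp8_along_last_dim helper sketched earlier, not the actual benchmark script):

```python
import torch

torch._inductor.config.triton.prefer_nd_tiling = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True

def cast_dim0_dim1(x):
    data_dim1, scale_dim1 = to_mxfp8_along_last_dim(x)
    data_dim0, scale_dim0 = to_mxfp8_along_last_dim(x.t().contiguous())
    return data_dim1, scale_dim1, data_dim0, scale_dim0

compiled_cast = torch.compile(cast_dim0_dim1)
x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
out = compiled_cast(x)
# run with TORCH_LOGS=output_code to inspect how many Triton kernels
# inductor generated for the two casts
```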

eellison self-assigned this Feb 27, 2025
vkuzo added a commit that referenced this issue Mar 5, 2025
Summary:

This PR enables `MXLinear` with `mxfp8_cublas` recipe to use
torch.compile.

The current approach is a short-term workaround until
pytorch/pytorch#148461 is done. Since we can't
use e8m0 in torchinductor or triton yet, we create a custom op wrapper
around `torch._scaled_mm` which takes `uint8` scales and does the cast to
e8m0 inside the wrapper, where torchinductor can't see it.
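
As an illustration of this workaround (hypothetical op name and signature, not the actual torchao code), the wrapper looks roughly like this:

```python
import torch
from torch.library import custom_op

# Hypothetical sketch: scales stay uint8 at the graph level and are only
# reinterpreted as e8m0 inside the custom op, where inductor can't see it.
@custom_op("demo::mx_scaled_mm", mutates_args=())
def mx_scaled_mm(a: torch.Tensor, b: torch.Tensor,
                 a_scale_u8: torch.Tensor, b_scale_u8: torch.Tensor) -> torch.Tensor:
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)  # uint8 -> e8m0 bit reinterpret
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)

@mx_scaled_mm.register_fake
def _(a, b, a_scale_u8, b_scale_u8):
    return torch.empty(a.shape[0], b.shape[1], dtype=torch.bfloat16, device=a.device)
```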

Test Plan:

```
# this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

# we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

ghstack-source-id: 033d817549f80d7d0d8cf549f748411cc1f3ac6a
ghstack-comment-id: 2701679811
Pull Request resolved: #1841