**torchao/float8/README.md** (20 additions, 3 deletions)

This is the default recipe, with a good balance of performance and accuracy.

```python
import time

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
M, K, N = 4096, 8192, 4096
m = nn.Sequential(
    nn.Linear(K, N, bias=False),
    nn.Linear(N, 128, bias=False),
).bfloat16().cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
# (representative filter; the original definition is collapsed in this diff)
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# warm up torch.compile for a clean training time measurement
for _ in range(1):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
start_time = time.time()

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

torch.cuda.synchronize()
end_time = time.time()
print("Training time:", end_time - start_time)
```
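
Note that the example times the loop with host-side `time.time()` bracketed by `torch.cuda.synchronize()`. An alternative is CUDA events, which record timestamps on the GPU stream itself. A minimal sketch of that approach (not part of the README change; the matmul is a stand-in for the timed work):

```python
import torch

# CUDA events capture timestamps on the device,
# so only one synchronize is needed before reading the result
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
start.record()
for _ in range(10):
    y = x @ x  # stand-in for the training step being timed
end.record()
torch.cuda.synchronize()  # wait until both events have completed
print(f"Elapsed: {start.elapsed_time(end):.2f} ms")
```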

## float8 linear with rowwise scaling
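
The body of this section is collapsed in the diff. For reference, a minimal sketch of enabling the rowwise recipe, assuming the `Float8LinearConfig.from_recipe_name` API exported by `torchao.float8`:

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# same toy model as in the default-recipe example above
m = nn.Sequential(
    nn.Linear(8192, 4096, bias=False),
    nn.Linear(4096, 128, bias=False),
).bfloat16().cuda()

# select rowwise scaling instead of the default tensorwise recipe
config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(m, config=config)
```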