From 0eb0616de0dffc0f30399ce26544a49849737144 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Tue, 3 Jun 2025 06:28:30 -0700
Subject: [PATCH] update float8 training readme to include time measurement

Summary:

Update the float8 training example code snippet to include a time
measurement that properly excludes torch.compile one-time warmup.
Also, use larger shapes to demonstrate the speedup from float8.

Test Plan:

Copy-paste the snippet and run it; it works. Commenting out the float8
conversion shows a slowdown, as expected.

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/float8/README.md | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 65da67c524..99bb80c4bd 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -17,6 +17,8 @@ and composable with key systems such as autograd, ```torch.compile``` and distri
 This is the default recipe, with a good balance of performance and accuracy.
 
 ```python
+import time
+
 import torch
 import torch.nn as nn
 from torchao.float8 import convert_to_float8_training
@@ -26,11 +28,12 @@ if not TORCH_VERSION_AT_LEAST_2_5:
     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
 
 # create model and sample input
+M, K, N = 4096, 8192, 4096
 m = nn.Sequential(
-    nn.Linear(2048, 4096),
-    nn.Linear(4096, 128),
+    nn.Linear(K, N, bias=False),
+    nn.Linear(N, 128, bias=False),
 ).bfloat16().cuda()
-x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
 # optional: filter modules from being eligible for float8 conversion
@@ -50,12 +53,26 @@ convert_to_float8_training(m, module_filter_fn=module_filter_fn)
 # enable torch.compile for competitive performance
 m = torch.compile(m)
 
+# warm up torch.compile for a clean training time measurement
+for _ in range(1):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+torch.cuda.synchronize()
+start_time = time.time()
+
 # toy training loop
 for _ in range(10):
     optimizer.zero_grad()
     y = m(x)
     y.sum().backward()
     optimizer.step()
+
+torch.cuda.synchronize()
+end_time = time.time()
+print("Training time:", end_time - start_time)
 ```
 
 ## float8 linear with rowwise scaling
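
The warmup-then-measure pattern this patch adds (run a few untimed iterations so torch.compile's one-time compilation cost is excluded, then synchronize the GPU before and after the timed loop) generalizes to any compiled model. Below is a minimal sketch of the same idea factored into a helper; `time_training_steps` and `train_step` are hypothetical names for illustration, not part of this patch or of torchao, and a CUDA device is assumed to be available.

```python
import time

import torch


def time_training_steps(train_step, num_warmup: int = 1, num_iters: int = 10) -> float:
    # run untimed warmup iterations so torch.compile's one-time
    # compilation cost is excluded from the measurement
    for _ in range(num_warmup):
        train_step()

    # wait for queued GPU work from warmup before starting the clock
    torch.cuda.synchronize()
    start_time = time.time()

    for _ in range(num_iters):
        train_step()

    # wait again so all timed GPU kernels have actually finished
    torch.cuda.synchronize()
    return time.time() - start_time
```

With the README's snippet, `train_step` would wrap one iteration of the toy training loop: `optimizer.zero_grad(); y = m(x); y.sum().backward(); optimizer.step()`.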