Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cast local_scale_tensor to fp32 for precompute of fp8 dynamic scaling #713

Merged
merged 4 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def init_multi_module(self) -> nn.Module:
self.broadcast_module(module)
return module

def init_transformer(self, weight_tying: bool) -> nn.Module:
def init_transformer(self, weight_tying: bool, dtype: torch.dtype | None = None) -> nn.Module:
torch.manual_seed(42)
args = ModelArgs(
n_layers=3,
Expand All @@ -70,6 +70,8 @@ def init_transformer(self, weight_tying: bool) -> nn.Module:
vocab_size=32,
)
module = Transformer(args).cuda()
if dtype is not None:
module = module.to(dtype=dtype)
self.broadcast_module(module)
return module

Expand All @@ -96,6 +98,7 @@ def test_transformer_parity(self):
ScalingType.DELAYED,
],
"compile_transformer_block": [False, True],
"dtype": [torch.float32, torch.bfloat16],
},
self._test_transformer_parity,
)
Expand All @@ -106,6 +109,7 @@ def _test_transformer_parity(
precompute: bool,
scaling_type_weight: ScalingType,
compile_transformer_block: bool,
dtype: torch.dtype | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe change this to dtype: Optional[torch.dtype] = None as well. I guess python 3.9 probably does not support torch.dtype | None

):
if not enable_fsdp_float8_all_gather and precompute:
return
Expand All @@ -117,7 +121,7 @@ def _test_transformer_parity(
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
# fp8 for that tied weight, incorrectly using fp8 for the embedding.
weight_tying = not enable_fsdp_float8_all_gather
module = self.init_transformer(weight_tying=weight_tying).cuda()
module = self.init_transformer(weight_tying=weight_tying, dtype=dtype)
ref_module = copy.deepcopy(module)
float8_linear_config1 = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
Expand Down
2 changes: 1 addition & 1 deletion torchao/float8/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
local_scale_tensor = scale_tensor.to_local()
local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for pinpoint to a specific line. I can also take a look at why max_weights = torch._foreach_norm(weights, ord=math.inf) does not ends up with float32. I was expecting local_scale_tensor to be the same dtype to model.parameters()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that we cast to float32 in

return res.to(torch.float32)

it's good to keep precompute in sync with float8_utils.py. thanks for the fix

for i, float8_linear in enumerate(float8_linears):
float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]

Expand Down
Loading