Update fp8_meta amax when copying into Float8Tensor #567

Merged (4 commits, Dec 16, 2023)
23 changes: 21 additions & 2 deletions transformer_engine/pytorch/float8_tensor.py
@@ -562,6 +562,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
 if dst._fp8_dtype == src._fp8_dtype:
     dst._data.copy_(src._data)
     dst._scale_inv = src._scale_inv.clone()
+    if dst._fp8_meta is not None:
+        if src._fp8_meta is None:
+            src_min, src_max = src.from_float8().aminmax()
+            src_amax = torch.maximum(-src_min, src_max)
+        else:
+            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                forward=src._fp8_meta_forward,
+            )
+            fp8_meta_index = src._fp8_meta_index
+            src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
+        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+            forward=dst._fp8_meta_forward,
+        )
+        fp8_meta_index = dst._fp8_meta_index
+        dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
+        torch.maximum(src_amax, dst_amax, out=dst_amax)
 else:
     dst.copy_(src.from_float8())
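This hunk folds the source tensor's amax into the destination's amax history when both tensors share an FP8 dtype: if the source has no fp8_meta, its amax is recomputed from the dequantized values; otherwise it is read from the source's own history, and the destination's history entry is then updated in place with the running maximum. A minimal standalone sketch of that merging semantics (merge_amax is a hypothetical helper for illustration, not part of Transformer Engine's API):

    import torch

    def merge_amax(src_vals: torch.Tensor, dst_amax: torch.Tensor) -> None:
        """Fold the abs-max of ``src_vals`` into ``dst_amax`` in place."""
        src_min, src_max = src_vals.aminmax()
        src_amax = torch.maximum(-src_min, src_max)  # abs-max without materializing abs()
        torch.maximum(src_amax, dst_amax, out=dst_amax)

    dst_amax = torch.tensor(1.0)
    merge_amax(torch.tensor([-3.0, 2.0]), dst_amax)
    print(dst_amax)  # tensor(3.) -- the history entry now covers the copied-in data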

@@ -582,11 +598,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
 # Update scaling factor if FP8 meta tensors are available
 if dst._fp8_meta is None:
     scale = dst._scale_inv.reciprocal()
+    amax = torch.empty_like(scale)
 else:
     fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
         forward=dst._fp8_meta_forward,
     )
-    scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index]
+    fp8_meta_index = dst._fp8_meta_index
+    scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
+    amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
 dst._scale_inv = scale.detach().view(1).reciprocal()

 # Cast to FP8
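This hunk threads the destination's amax-history entry through to the cast below: when fp8_meta is available, both scale and amax come from it; otherwise a throwaway amax buffer is allocated. The surrounding scale bookkeeping works because the FP8 payload stores x * scale, so dequantization needs scale_inv = 1 / scale. A hypothetical illustration (the 256.0 scale is an arbitrary example value):

    import torch

    scale = torch.tensor([256.0])                    # quantization scale (from fp8_meta)
    scale_inv = scale.detach().view(1).reciprocal()  # mirrors dst._scale_inv above
    x = torch.tensor([0.01])
    fp8_payload = x * scale                          # value the FP8 buffer encodes
    assert torch.allclose(fp8_payload * scale_inv, x)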
@@ -596,7 +615,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
     src.view(1, -1),
     scale,
     dst._data.view(1, -1),
-    torch.empty_like(dst._scale_inv),  # amax
+    amax,
     dst._scale_inv,
     dst._fp8_dtype,
 )
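The fused cast now receives the fp8_meta amax-history entry instead of a scratch buffer (the old torch.empty_like), so the cast's amax side output lands in the destination's history. A rough pure-PyTorch sketch of that contract, for orientation only: the real op is a Transformer Engine CUDA kernel, cast_to_fp8_sketch is a hypothetical stand-in, and the saturation value 448 assumes an E4M3-style dtype.

    import torch

    def cast_to_fp8_sketch(src, scale, data_out, amax_out, scale_inv_out, fp8_max=448.0):
        # Record the input's abs-max as a side output -- with this PR that is
        # the fp8_meta amax-history entry rather than a throwaway buffer.
        amax_out.copy_(src.abs().max())
        # Scale and saturate; data_out stands in for the real FP8 payload.
        data_out.copy_((src * scale).clamp(-fp8_max, fp8_max))
        # Store the reciprocal scale for later dequantization.
        scale_inv_out.copy_(scale.reciprocal())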