float8: remove unneeded kernel for scale generation (#616)
Summary:

The code that creates a float8 scale was unnecessarily launching an extra GPU
kernel by calling `torch.empty`; this commit removes that call.

There is no performance impact, but the change makes debugging easier by reducing log size and simplifying GPU traces.
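To illustrate the pattern, a minimal before/after sketch (the tensor value and the `1e-12` epsilon are illustrative stand-ins; only the allocate-then-copy structure mirrors the actual change):

```
import torch

amax = torch.tensor([3.0])
f8_max = torch.finfo(torch.float8_e4m3fn).max  # assumes a float8-capable PyTorch build

# Before: preallocate a destination tensor and copy into it. Per the summary
# above, the torch.empty_like call shows up as an extra kernel in GPU traces.
scale = torch.empty_like(amax, dtype=torch.float32)
res = f8_max / torch.clamp(amax, min=1e-12)
scale.copy_(res)

# After: compute the result and cast directly, with no separate allocation.
scale = (f8_max / torch.clamp(amax, min=1e-12)).to(torch.float32)
```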

Test Plan:

```
# extract a trace of a linear fwd+bwd with
python benchmarks/float8/profile_linear_float8.py ~/local/tmp/test
# verify that the GPU kernel creating an empty scale tensor is no longer in the trace

# unit tests pass
./test/float8/test_everything.sh
```
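For a programmatic variant of the same check, a minimal sketch using `torch.profiler`; the plain `nn.Linear` is a hypothetical stand-in for the float8 linear exercised by the benchmark script, and a CUDA device is assumed:

```
import torch
from torch.profiler import ProfilerActivity, profile

# Hypothetical stand-in workload for a float8 linear fwd+bwd.
lin = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda", requires_grad=True)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    lin(x).sum().backward()

# Print the recorded op/kernel names; after this change, no event should
# exist solely to materialize an empty scale tensor.
for evt in prof.key_averages():
    print(evt.key)
```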

vkuzo authored Aug 7, 2024
1 parent 8bba8ed commit d582f9a
Showing 1 changed file with 1 addition and 3 deletions.
torchao/float8/float8_utils.py
```diff
@@ -42,7 +42,6 @@ def amax_to_scale(
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
     """
-    scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
@@ -53,8 +52,7 @@ def amax_to_scale(
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=torch.finfo(torch.float16).max)
-    scale.copy_(res)
-    return scale
+    return res.to(torch.float32)
 
 
 @torch.no_grad()
```
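For reference, a sketch of `amax_to_scale` as it reads after this change, reconstructed from the hunks above. The `FP8_TYPES` and `EPS` values are stubs for module-level constants defined elsewhere in `float8_utils.py`, and the `else` branch (elided in the diff) is assumed to raise:

```
import torch

# Stubs for constants defined elsewhere in torchao/float8/float8_utils.py;
# the real values may differ.
FP8_TYPES = {torch.float8_e4m3fn, torch.float8_e5m2}
EPS = 1e-12

def amax_to_scale(
    amax: torch.Tensor,
    float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
) -> torch.Tensor:
    if float8_dtype in FP8_TYPES:
        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    # float16 has a smaller dynamic range than float32/bfloat16, so clamp
    # the scale to avoid overflow for float16 inputs.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
    # The change: cast directly instead of copying into a preallocated tensor.
    return res.to(torch.float32)
```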
