[subclasses] Use __slots__ for micro optim of flatten/unflatten
[ghstack-poisoned]
IvanKobzarev committed Nov 1, 2024
1 parent 2761917 commit 746c10d
Showing 1 changed file with 16 additions and 0 deletions.
torchao/float8/fsdp_utils.py (16 additions, 0 deletions)
@@ -128,6 +128,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 # | TP compute with torch.mm(input, weight)

 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+
+    __slots__ = "_tensor", "_precomputed_scale", "_linear_mm_config"
+
     @staticmethod
     def __new__(
         cls,
@@ -258,6 +261,16 @@ def fsdp_post_all_gather(


 class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+
+    __slots__ = [
+        "_tensor",
+        "_amax_buffer",
+        "_amax_history_buffer",
+        "_scale_buffer",
+        "_linear_mm_config",
+        "is_amax_initialized"
+    ]
+
     @staticmethod
     def __new__(
         cls,
@@ -439,6 +452,9 @@ def fsdp_post_all_gather(


 class WeightWithStaticFloat8CastTensor(torch.Tensor):
+
+    __slots__ = "_tensor", "_static_scale", "_linear_mm_config"
+
     @staticmethod
     def __new__(
         cls,
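For context on the commit title: defining __slots__ on a class gives the listed attributes dedicated slot descriptors instead of routing every access through a per-instance __dict__ lookup, which is a small win on hot paths that repeatedly read those attributes, such as the flatten/unflatten code of these weight wrappers. The snippet below is a minimal standalone sketch of that effect, not the torchao code; the class names, attribute names, and flatten() helper are hypothetical.

import timeit

class WithDict:
    def __init__(self, tensor, scale, config):
        self._tensor = tensor
        self._scale = scale
        self._config = config

    def flatten(self):
        # Mirrors a flatten-style accessor: read every wrapped attribute.
        return (self._tensor, self._scale, self._config)

class WithSlots:
    # Same wrapper, but with __slots__: attribute access goes through
    # slot descriptors rather than an instance __dict__ lookup.
    __slots__ = ("_tensor", "_scale", "_config")

    def __init__(self, tensor, scale, config):
        self._tensor = tensor
        self._scale = scale
        self._config = config

    def flatten(self):
        return (self._tensor, self._scale, self._config)

if __name__ == "__main__":
    with_dict = WithDict(1, 2.0, "cfg")
    with_slots = WithSlots(1, 2.0, "cfg")
    print("dict :", timeit.timeit(with_dict.flatten, number=1_000_000))
    print("slots:", timeit.timeit(with_slots.flatten, number=1_000_000))

Note that torch.Tensor instances already carry a __dict__, so a __slots__ declaration on these subclasses does not remove it; the listed names simply gain slot descriptors, which is presumably where the micro-optimization mentioned in the title comes from.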
