Skip to content

[subclasses] Use __slots__ for micro optim of flatten/unflatten #1211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/IvanKobzarev/3/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,17 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tenso
class NF4Tensor(torch.Tensor):
"""NF4Tensor class for converting a weight to the QLoRA NF4 format"""

__slots__ = [
"quantized_data",
"scaler_mean",
"quantization_factor",
"quantized_scalers",
"nf4",
"block_size",
"n_blocks",
"scaler_block_size",
]

@torch._dynamo.disable
def __new__(
cls,
Expand Down
16 changes: 16 additions & 0 deletions torchao/float8/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
# | TP compute with torch.mm(input, weight)

class WeightWithDynamicFloat8CastTensor(torch.Tensor):

__slots__ = "_tensor", "_precomputed_scale", "_linear_mm_config"

@staticmethod
def __new__(
cls,
Expand Down Expand Up @@ -258,6 +261,16 @@ def fsdp_post_all_gather(


class WeightWithDelayedFloat8CastTensor(torch.Tensor):

__slots__ = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious is __slots__ derived from __tensor_flatten__ ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, here I just collected all attributes used in tensor_flatten, tensor_unflatten

"_tensor",
"_amax_buffer",
"_amax_history_buffer",
"_scale_buffer",
"_linear_mm_config",
"is_amax_initialized"
]

@staticmethod
def __new__(
cls,
Expand Down Expand Up @@ -439,6 +452,9 @@ def fsdp_post_all_gather(


class WeightWithStaticFloat8CastTensor(torch.Tensor):

__slots__ = "_tensor", "_static_scale", "_linear_mm_config"

@staticmethod
def __new__(
cls,
Expand Down
Loading