
Commit a2d0045 (parent: 2761917)

[subclasses] Use __slots__ for micro optim of flatten/unflatten

ghstack-source-id: 29e856540122dd6d0a8d3a522617234af70a6ca3
Pull Request resolved: #1211

2 files changed (+27, -0)

torchao/dtypes/nf4tensor.py (+11)

@@ -455,6 +455,17 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""
 
+    __slots__ = [
+        "quantized_data",
+        "scaler_mean",
+        "quantization_factor",
+        "quantized_scalers",
+        "nf4",
+        "block_size",
+        "n_blocks",
+        "scaler_block_size",
+    ]
+
     @torch._dynamo.disable
     def __new__(
         cls,
torchao/float8/fsdp_utils.py (+16)

@@ -128,6 +128,9 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # | TP compute with torch.mm(input, weight)
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
+
+    __slots__ = "_tensor", "_precomputed_scale", "_linear_mm_config"
+
     @staticmethod
     def __new__(
         cls,
@@ -258,6 +261,16 @@ def fsdp_post_all_gather(
 
 
 class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+
+    __slots__ = [
+        "_tensor",
+        "_amax_buffer",
+        "_amax_history_buffer",
+        "_scale_buffer",
+        "_linear_mm_config",
+        "is_amax_initialized"
+    ]
+
     @staticmethod
     def __new__(
         cls,
@@ -439,6 +452,9 @@ def fsdp_post_all_gather(
 
 
 class WeightWithStaticFloat8CastTensor(torch.Tensor):
+
+    __slots__ = "_tensor", "_static_scale", "_linear_mm_config"
+
     @staticmethod
     def __new__(
         cls,
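As a rough way to see the micro-optimization the commit message refers to, here is a timeit sketch (an illustration under assumptions, not code from this repository; both classes are hypothetical stand-ins) that reads the three fields of a dynamic-cast-style weight the way a flatten would:

import timeit

class DictWeight:  # attributes resolved through the per-instance __dict__
    def __init__(self):
        self._tensor = None
        self._precomputed_scale = None
        self._linear_mm_config = None

class SlotWeight:  # attributes resolved through fixed slot descriptors
    __slots__ = ("_tensor", "_precomputed_scale", "_linear_mm_config")

    def __init__(self):
        self._tensor = None
        self._precomputed_scale = None
        self._linear_mm_config = None

def read_all(w):
    # Touch every field, as flattening this subclass would.
    return w._tensor, w._precomputed_scale, w._linear_mm_config

d, s = DictWeight(), SlotWeight()
print("dict :", timeit.timeit(lambda: read_all(d), number=1_000_000))
print("slots:", timeit.timeit(lambda: read_all(s), number=1_000_000))

The per-access difference is small, consistent with the commit calling this a micro-optimization; it pays off only where flatten/unflatten is called frequently.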
