Skip to content

Commit

Permalink
Fix Safe Load for NF4 (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg authored Nov 12, 2024
1 parent ccd883b commit 77ca57d
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 45 deletions.
2 changes: 2 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ include = [
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
"test/prototype/low_bit_optim/**.py",
"torchao/utils.py",

]

lint.ignore = ["E731"]
6 changes: 2 additions & 4 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
assert base_mod.param.block_size == 32
assert base_mod.param.scaler_block_size == 2

- @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
- input_tensor = torch.rand(64, device="cuda", dtype=dtype)
+ input_tensor = torch.rand(64, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand All @@ -184,11 +183,10 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
assert other_mod.param.block_size == 32
assert other_mod.param.scaler_block_size == 2

- @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
- input_tensor = torch.rand(128, device="cuda", dtype=dtype)
+ input_tensor = torch.rand(128, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
Expand Down
6 changes: 6 additions & 0 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch._prims_common import make_contiguous_strides_for
from torch.distributed.device_mesh import DeviceMesh

+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
Expand Down Expand Up @@ -1043,3 +1045,7 @@ def nf4_constructor(
quantized_data,
nf4,
)


+ if TORCH_VERSION_AT_LEAST_2_5:
+     torch.serialization.add_safe_globals([NF4Tensor])
Loading

0 comments on commit 77ca57d

Please sign in to comment.