diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index 2789315169..3e8b89f9f0 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -1,6 +1,7 @@ import logging import unittest from packaging import version +import math import torch from torch import nn @@ -10,11 +11,17 @@ parametrize, run_tests, ) -from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4 +from torchao.dtypes.nf4tensor import ( + linear_nf4, + NF4Tensor, + to_nf4, + _INNER_TENSOR_NAMES_FOR_SHARDING, +) import torch.nn.functional as F import io from collections import OrderedDict import torchao +from typing import Tuple, Union bnb_available = False @@ -234,8 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype): a_nf4 = torchao.dtypes.to_nf4(a, 16, 2) inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device) out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) - - + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @parametrize("shape", [(16, 16), (32, 16)]) @@ -250,7 +256,184 @@ def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size): torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data) + +class TestFSDPOps(TestCase): + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + nf4_tensor = to_nf4(torch.randn(input_size)) + chunks = list(torch.chunk(nf4_tensor, num_chunks)) + self.assertEqual(len(chunks), num_chunks) + if isinstance(input_size, int): + expected_size0 = input_size // num_chunks + else: + expected_size0 = input_size[0] // num_chunks + for chunk in chunks: + self.assertEqual(chunk.size(0), expected_size0) + + @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)]) + def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"): + nf4_tensor = to_nf4(torch.randn(input_size)) + torch.chunk(nf4_tensor, num_chunks) + + @parametrize("input_size", [(512, 512, 512)]) + def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]): + num_chunks = 2 + with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"): + nf4_tensor = to_nf4(torch.randn(input_size)) + torch.chunk(nf4_tensor, num_chunks) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + nf4_tensor_zeros = nf4_tensor.new_zeros(input_size) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4_tensor_zeros, attr) + self.assertEqual(torch.count_nonzero(inner_tensor), 0) + expected_size = input_size if not isinstance(input_size, int) else (input_size, ) + self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size)) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]): + if isinstance(input_size, int): + new_size = input_size + 1 + elif len(input_size) == 1: + new_size = (input_size[0] + 1, ) + else: + new_size = (input_size[0] + 1, input_size[1]) + nf4_tensor = to_nf4(torch.randn(input_size)) + with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"): + 
nf4_tensor_zeros = nf4_tensor.new_zeros(new_size) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + orig_attrs, _ = nf4_tensor.__tensor_flatten__() + orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs]) + end_idx = input_size if isinstance(input_size, int) else input_size[0] + sliced_tensor = nf4_tensor[:end_idx] + self.assertEqual(nf4_tensor.size(), sliced_tensor.size()) + attrs, _ = sliced_tensor.__tensor_flatten__() + for attr in attrs: + orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr() + sliced_tensor_inner = getattr(sliced_tensor, attr) + self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage) + self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr]) + + def test_tensor_slice_1d_invalid(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with customized step"): + nf4_tensor[..., ::2] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"): + nf4_tensor[:2] + + def test_tensor_slice_2d_invalid(self): + nf4_tensor = to_nf4(torch.randn((512, 512))) + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"): + nf4_tensor[:, :511] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): + nf4_tensor[1:] + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"): + nf4_tensor[:2] + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + viewed_tensor = nf4_tensor.view(-1) + self.assertEqual(viewed_tensor.dim(), 1) + self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(viewed_tensor, attr) + self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + if len(input_size) == 1: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): + nf4_tensor.view(input_size) + if len(input_size) == 2: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"): + nf4_tensor.view(input_size) + + @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) + def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset()) + self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size()) + self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride()) + self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset()) + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor_orig = getattr(nf4_tensor, attr) + inner_tensor_strided = getattr(nf4_tensor_strided, attr) + self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size()) + self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride()) + 
self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset()) + + + @parametrize("input_size", [(512 * 512,), (512, 512)]) + def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]): + nf4_tensor = to_nf4(torch.randn(input_size)) + if len(input_size) == 1: + size = (input_size[0] - 1, ) + else: + size = (input_size[0] - 1, input_size[1]) + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"): + torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset()) + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"): + torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1) + + if len(input_size) == 2: + with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"): + stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0]) + torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset()) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_pin_memory(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertFalse(nf4_tensor.is_pinned()) + + nf4_tensor = nf4_tensor.pin_memory() + self.assertTrue(nf4_tensor.is_pinned()) + + nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) + self.assertFalse(nf4_tensor.is_pinned()) + + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_cuda(self): + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda", non_blocking=True) + self.assertEqual(nf4_tensor.device.type, "cuda") + + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda") + self.assertEqual(nf4_tensor.device.type, "cuda") + + nf4_tensor = to_nf4(torch.randn(512 * 512)) + self.assertEqual(nf4_tensor.device.type, "cpu") + nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16) + self.assertEqual(nf4_tensor.device.type, "cuda") + self.assertEqual(nf4_tensor.dtype, torch.bfloat16) + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_to_cpu(self): + nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) + nf4_tensor = nf4_tensor.cpu() + self.assertEqual(nf4_tensor.device.type, "cpu") + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4_tensor, attr) + self.assertEqual(inner_tensor.device.type, "cpu") + + instantiate_parametrized_tests(TestNF4Linear) +instantiate_parametrized_tests(TestFSDPOps) if __name__ == "__main__": run_tests() diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 4628ce9949..48249434b7 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -2,19 +2,28 @@ from dataclasses import dataclass import math from typing import Dict, Tuple +import math +import sys +from enum import Enum, auto import torch import torch.nn.functional as F +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch._prims_common import make_contiguous_strides_for aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional -from typing import Any, Tuple +from typing import Any, Optional, Tuple, Union, List NF4_OPS_TABLE: Dict[Any, Any] = {} + +_INNER_TENSOR_NAMES_FOR_SHARDING = ["quantized_scalers", "quantization_factor", "quantized_data"] + # Note: Quantize in Chunks # During 
quantization to NF4, one of the steps to convert from the original float number # to the index of the nearest value in the NF4 format. This can cause a large memory spike @@ -45,11 +54,219 @@ def decorator(func): return decorator -@implements([torch.ops.aten.detach.default, torch.ops.aten.detach]) +def construct_nf4_args(nf4tensor: "NF4Tensor", kwargs: Optional[Dict[str, Any]] = None): + if kwargs is None: + kwargs = {} + tensor_meta = SubclassTensorArgs( + kwargs.get("size", nf4tensor.size()), + kwargs.get("stride", nf4tensor.stride()), + kwargs.get("storage_offset", nf4tensor.storage_offset()), + kwargs.get("dtype", nf4tensor.dtype), + kwargs.get("device", nf4tensor.device), + kwargs.get("requires_grad", nf4tensor.requires_grad), + ) + return ( + tensor_meta, + kwargs.get("block_size", nf4tensor.block_size), + kwargs.get("n_blocks", nf4tensor.n_blocks), + kwargs.get("scaler_block_size", nf4tensor.scaler_block_size), + kwargs.get("quantized_scalers", nf4tensor.quantized_scalers), + kwargs.get("quantization_factor", nf4tensor.quantization_factor), + kwargs.get("scaler_mean", nf4tensor.scaler_mean), + kwargs.get("quantized_data", nf4tensor.quantized_data), + kwargs.get("nf4", nf4tensor.nf4), + ) + + +# __torch_dispatch__ utils: apply aten op to inner tensors +def apply_to_inner_tensors(nf4tensor: "NF4Tensor", aten_op, args, kwargs): + attr_to_tensor = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + attr_to_tensor[attr] = aten_op(getattr(nf4tensor, attr), *args, **kwargs) + return attr_to_tensor + +# __torch_function__ utils: call tensor ops from inner tensors +def call_from_inner_tensors(nf4tensor: "NF4Tensor", method_name: str, args, kwargs): + attr_to_tensor = {} + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + func = getattr(inner_tensor, method_name) + attr_to_tensor[attr] = func(*args, **kwargs) + return attr_to_tensor + +class CompareOp(Enum): + EQ = auto() + LT = auto() + +def expect_num_of_args(op: CompareOp, num: int, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.LT and not (len(args) < num): + raise NotImplementedError(msg) + return func(aten_op, args, kwargs) + return wrapper + return decorator + +def expect_arg_value_at_k(k: int, op: CompareOp, value: Any, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.EQ and not (args[k] == value): + raise NotImplementedError(msg + str(args[k])) + return func(aten_op, args, kwargs) + return wrapper + return decorator + +def expect_args_len_at_k(k: int, op: CompareOp, value: Any, msg: str): + def decorator(func): + @functools.wraps(func) + def wrapper(aten_op, args, kwargs=None): + if op == CompareOp.LT and not (len(args[k]) < value): + raise NotImplementedError(msg + str(len(args[k]))) + elif op == CompareOp.EQ and not (len(args[k]) == value): + raise NotImplementedError(msg + str(len(args[k]))) + return func(aten_op, args, kwargs) + return wrapper + return decorator + + +@implements([torch.ops.aten.detach]) def noop_detach(func, *args, **kwargs): return args[0][0] +@implements( + [ + aten.detach.default, + ] +) +def nf4_detach(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + +@implements( + [ + aten.split.Tensor, + ] +) +def nf4_split(aten_op, args, kwargs=None): + if len(args) == 3 and args[2] != 
0:
+        raise NotImplementedError(f"aten.split(NF4Tensor, dim={args[2]})")
+    nf4tensor = args[0]
+    num_chunks = nf4tensor.size(0) // args[1]
+
+    attr_to_chunks = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        inner_tensor = getattr(nf4tensor, attr)
+        assert inner_tensor.numel() % num_chunks == 0, f"{attr}.numel() not divisible by {num_chunks}"
+        chunks = aten_op(inner_tensor, inner_tensor.numel() // num_chunks, **kwargs)
+        attr_to_chunks[attr] = chunks
+
+    orig_dim = nf4tensor.dim()
+    if orig_dim == 1:
+        chunked_size = (nf4tensor.size(0) // num_chunks, )
+    elif orig_dim == 2:
+        chunked_size = (nf4tensor.size(0) // num_chunks, nf4tensor.size(1))
+    else:
+        chunked_size = ()
+        raise NotImplementedError(f"aten.split(NF4Tensor) where NF4Tensor.dim() = {orig_dim}")
+
+    nf4_chunks = []
+    for idx in range(num_chunks):
+        updated_attrs = {
+            "size": chunked_size
+        }
+        for attr, chunks in attr_to_chunks.items():
+            updated_attrs[attr] = chunks[idx]
+        nf4_chunks.append(NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)))
+    return nf4_chunks
+
+@implements(
+    [
+        aten.new_zeros.default,
+    ]
+)
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.new_zeros(NF4Tensor) with len(size)=")
+def nf4_new_zeros(aten_op, args, kwargs=None):
+    nf4tensor = args[0]
+    new_size = tuple(args[1])
+    new_size_dim = len(new_size)
+    if nf4tensor.numel() % math.prod(new_size) != 0:
+        raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
+    ratio = nf4tensor.numel() // math.prod(new_size)
+
+    updated_attrs = {}
+    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
+        inner_tensor = getattr(nf4tensor, attr)
+        assert inner_tensor.size(0) % ratio == 0, f"{attr}.numel() must be divisible by {ratio}"
+        inner_tensor = aten_op(inner_tensor, [inner_tensor.size(0) // ratio], **kwargs)
+        updated_attrs[attr] = inner_tensor
+    updated_attrs["size"] = new_size
+
+    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+@implements(
+    [
+        aten.slice.Tensor,
+    ]
+)
+@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step")
+@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=")
+@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=")
+def nf4_slice(aten_op, args, kwargs=None):
+    nf4tensor = args[0]
+    # for tensor 512 x 512, tensor[:, :512] dispatches to
+    # aten.slice(dim = 0, end=sys.maxsize)
+    if not args[3] in [nf4tensor.size(0), sys.maxsize]:
+        raise NotImplementedError(f"aten.slice(NF4Tensor) with end={args[3]}")
+    return NF4Tensor(*construct_nf4_args(nf4tensor))
+
+@implements(
+    [
+        aten.view.default,
+    ]
+)
+@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
+def nf4_view(aten_op, args, kwargs=None):
+    nf4tensor = args[0]
+    size = args[1]
+    if size[0] != -1:
+        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
+    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
+    updated_attrs.update({
+        "size": [nf4tensor.numel()],
+        "stride": (1, ),
+    })
+    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+@implements(
+    [
+        aten.as_strided.default,
+    ]
+)
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=")
+def nf4_as_strided(aten_op, args, kwargs=None):
+    nf4tensor = args[0]
+    size = args[1]
+    stride = tuple(args[2])
+    storage_offset = args[3]
+    if math.prod(size) != nf4tensor.numel():
+        raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}")
+    if 
stride != make_contiguous_strides_for(size): + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") + if nf4tensor.storage_offset() != storage_offset: + raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") + kwargs = { + "size": torch.Size(size), + "stride": stride, + "storage_offset": storage_offset, + } + return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) + + @implements([torch.ops.aten._to_copy.default]) def _to_copy(func, *args, **kwargs): if not args[0][0].is_contiguous(): @@ -128,6 +345,31 @@ def copy_(func, *args, **kwargs): return original.copy_(same_meta_nf4) +@implements( + [ + aten.is_pinned.default, + ] +) +def nf4_is_pinned(aten_op, args, kwargs=None): + nf4tensor = args[0] + for attr in _INNER_TENSOR_NAMES_FOR_SHARDING: + inner_tensor = getattr(nf4tensor, attr) + if not aten_op(inner_tensor, *(args[1:]), **kwargs): + return False + return True + + +@implements( + [ + aten._pin_memory.default, + ] +) +def nf4_pin_memory(aten_op, args, kwargs=None): + nf4tensor = args[0] + updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) + + @dataclass class SubclassTensorArgs: original_shape: torch.Size @@ -232,7 +474,7 @@ def from_tensor( block_size: int, scaler_block_size: int, ): - assert inpt_tensor.dim() <= 2 + assert inpt_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {inpt_tensor.dim()}" assert ( inpt_tensor.numel() % block_size == 0 ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}" @@ -553,6 +795,67 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return func(*args, **kwargs) + def fsdp_pre_all_gather(self, mesh: DeviceMesh) -> Tuple[Tuple[torch.Tensor, ...], Any]: + return ( + self.quantized_scalers, + self.quantization_factor, + self.quantized_data, + ), ( + SubclassTensorArgs( + self.size(), + self.stride(), + self.storage_offset(), + self.dtype, + self.device, + self.requires_grad, + ), + self.block_size, + self.n_blocks, + self.scaler_block_size, + self.scaler_mean, + self.nf4, + mesh.get_group().size(), + ) + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]: + (quantized_scalers, quantization_factor, quantized_data) = all_gather_outputs + (tensor_meta, block_size, n_blocks, scaler_block_size, scaler_mean, nf4, pg_size) = metadata + if len(tensor_meta.original_shape) != 2: + raise NotImplementedError(f"only support 2D shape but got dim={len(tensor_meta.original_shape)}") + tensor_meta.original_shape = torch.Size((tensor_meta.original_shape[0] * pg_size, tensor_meta.original_shape[1])) + if out is not None: + # TODO: add param dtype for mixed precision + assert isinstance(out, NF4Tensor), f"{type(out)}" + assert ( + quantized_scalers.untyped_storage().data_ptr() + == out.quantized_scalers.untyped_storage().data_ptr() and + quantization_factor.untyped_storage().data_ptr() + == out.quantization_factor.untyped_storage().data_ptr() and + quantized_data.untyped_storage().data_ptr() + == out.quantized_data.untyped_storage().data_ptr() + ), f"Expects out's data to be the all-gather output" + return + + return NF4Tensor( + tensor_meta, + 
block_size, + n_blocks, + scaler_block_size, + quantized_scalers, + quantization_factor, + scaler_mean, + quantized_data, + nf4, + ), (quantized_scalers, quantization_factor, quantized_data) + + class LinearNF4(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor, weight: NF4Tensor): @@ -595,12 +898,33 @@ def decorator(func): @implements_torch_function(torch.Tensor.to) def function_to_dtype(*args, **kwargs): - if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype): + tensor = args[0] + if isinstance(args[1], torch.dtype): # Tensor.to(dtype, non_blocking, copy, memory_format) - return args[0].get_original_weight().to(*args[1:], **kwargs) + return tensor.get_original_weight().to(*args[1:], **kwargs) + elif ( + isinstance(args[1], torch.device) or ( + isinstance(args[1], str) and ( + args[1] == "cpu" or args[1].startswith("cuda") + ) + ) + ) and len(args) == 2: + # Tensor.to(device, non_blocking) + device = args[1] + updated_attrs = call_from_inner_tensors(tensor, "to", args[1:], kwargs) + updated_attrs["device"] = device + return NF4Tensor(*construct_nf4_args(tensor, updated_attrs)) else: # Tensor.to(device, dtype, non_blocking, copy, memory_format) # Tensor.to(other, non_blocking, copy) raise NotImplementedError( f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch" ) + + +@implements_torch_function(torch.Tensor.cpu) +def function_cpu(*args, **kwargs): + nf4tensor = args[0] + updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs) + updated_attrs["device"] = "cpu" + return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
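
For reference, a minimal sketch of the sharding-related ops this patch implements, mirroring TestFSDPOps above. It assumes only the public torchao.dtypes.nf4tensor.to_nf4 API; FSDP2 itself would reach these ops through the fsdp_pre_all_gather/fsdp_post_all_gather hooks and its flat-parameter sharding path rather than calling them directly.

# Illustrative sketch only -- exercises the ops added above, mirroring TestFSDPOps.
import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4 = to_nf4(torch.randn(512, 512))      # quantize a 2D weight with default block sizes

shards = torch.chunk(nf4, 2)             # dim-0 chunking used when sharding (aten.split)
flat = nf4.view(-1)                      # flatten to 1D, as done for flat parameters
zeros = nf4.new_zeros((512, 512))        # all-zero NF4Tensor with the same numel
alias = nf4[:512]                        # full-range slice shares the inner tensors' storage
strided = torch.as_strided(nf4, nf4.size(), nf4.stride(), nf4.storage_offset())

if torch.cuda.is_available():
    pinned = nf4.pin_memory()                   # pins every inner tensor
    on_gpu = nf4.to("cuda", non_blocking=True)  # device-only .to() stays quantized
    back_on_cpu = on_gpu.cpu()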