[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150

Merged · 47 commits · May 1, 2024

Changes from 30 commits

Commits
0a13e6a — proof of concept for FSDP2 + NF4Tensor (weifengpy, Apr 4, 2024)
9a56eaa — Merge branch 'main' into main (cpuhrsch, Apr 4, 2024)
8180540 — fsdp extention for tensor subclass (weifengpy, Apr 11, 2024)
95b03e1 — support fp32 (weifengpy, Apr 15, 2024)
3ac9d81 — Merge branch 'pytorch-labs:main' into main (weifengpy, Apr 16, 2024)
38461b3 — UNIT TEST FOR STATE DICT (weifengpy, Apr 16, 2024)
bc7a764 — implement to (weifengpy, Apr 17, 2024)
8b1d037 — remove torch.override from torch function (weifengpy, Apr 17, 2024)
7ff6855 — use dtype in compile unit test (weifengpy, Apr 17, 2024)
d9bcf71 — add dtype in all unit test (weifengpy, Apr 17, 2024)
923bef2 — keep original dtype (weifengpy, Apr 17, 2024)
e15d244 — fix linter (weifengpy, Apr 17, 2024)
d4beb8f — use torch testing @parametrize (weifengpy, Apr 17, 2024)
f41cb3d — remove unused import (weifengpy, Apr 17, 2024)
952fbdd — Merge branch 'pytorch-labs:main' into main (weifengpy, Apr 17, 2024)
950d9fd — sm8 for fp16 (weifengpy, Apr 17, 2024)
d4eae0b — remove sm check for fp16 (weifengpy, Apr 18, 2024)
9444f2c — skip 2.2.2 and below for tracing tensor subclass (weifengpy, Apr 18, 2024)
b2c3c02 — Merge branch 'pytorch-labs:main' into main (weifengpy, Apr 18, 2024)
9be2de3 — include kwargs (weifengpy, Apr 19, 2024)
2981393 — raise unimplemented (weifengpy, Apr 19, 2024)
3ced998 — Merge branch 'main' into main (weifengpy, Apr 19, 2024)
3f1e19a — Merge branch 'pytorch-labs:main' into main (weifengpy, Apr 19, 2024)
761416a — fsdp2 ops (weifengpy, Apr 19, 2024)
c656f1e — better diff layout (weifengpy, Apr 19, 2024)
c56d7e2 — set pg size in metadata (weifengpy, Apr 19, 2024)
d656b93 — remove irrelevant changes (weifengpy, Apr 19, 2024)
5c4fe2b — add unit test (weifengpy, Apr 20, 2024)
613bf67 — Merge branch 'main' into main (msaroufim, Apr 26, 2024)
3933bfa — torch.chunk and cpu offloading ops (weifengpy, Apr 27, 2024)
9e6b4ec — remove strict same metadata check (weifengpy, Apr 27, 2024)
857b8db — skip tests that needs cuda (weifengpy, Apr 27, 2024)
8e3de02 — use /( in regex match (weifengpy, Apr 27, 2024)
912998b — fix regex (weifengpy, Apr 28, 2024)
8926ee1 — skip tests if no cuda (weifengpy, Apr 28, 2024)
6f834ce — skip unit test if no cuda (weifengpy, Apr 28, 2024)
a8a5aaa — Merge branch 'pytorch:main' into main (weifengpy, Apr 28, 2024)
699079d — assert cpu device (weifengpy, Apr 30, 2024)
c8b047c — name args[0] as nf4tensor (weifengpy, Apr 30, 2024)
925602c — utils for apply to inner tensors and constructor (weifengpy, Apr 30, 2024)
e36ab6c — use original copy_ (weifengpy, Apr 30, 2024)
a007027 — decorator for args check (weifengpy, May 1, 2024)
c352552 — Merge branch 'main' into main (cpuhrsch, May 1, 2024)
c83fdad — INNER_TENSOR_NAMES_FOR_SHARDING and unify assert in split and new_zeros (weifengpy, May 1, 2024)
574fecd — Merge branch 'pytorch:main' into main (weifengpy, May 1, 2024)
f27760b — indicate private constant with _ (weifengpy, May 1, 2024)
b4f51b9 — Merge branch 'main' into fsdp2ops (weifengpy, May 1, 2024)
test/dtypes/test_nf4.py (167 additions, 0 deletions)
@@ -1,6 +1,7 @@
import logging
import unittest
from packaging import version
import math

import torch
from torch import nn
@@ -15,6 +16,7 @@
import io
from collections import OrderedDict
import torchao
from typing import Tuple, Union


bnb_available = False
@@ -236,7 +238,172 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
        out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)


class TestFSDPOps(TestCase):
    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        nf4_tensor = to_nf4(torch.randn(input_size))
        chunks = list(torch.chunk(nf4_tensor, num_chunks))
        self.assertEqual(len(chunks), num_chunks)
        if isinstance(input_size, int):
            expected_size0 = input_size // num_chunks
        else:
            expected_size0 = input_size[0] // num_chunks
        for chunk in chunks:
            self.assertEqual(chunk.size(0), expected_size0)

    @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)])
    def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"):
            nf4_tensor = to_nf4(torch.randn(input_size))
            torch.chunk(nf4_tensor, num_chunks)

    @parametrize("input_size", [(512, 512, 512)])
    def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
        with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"):
            nf4_tensor = to_nf4(torch.randn(input_size))
            torch.chunk(nf4_tensor, num_chunks)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        nf4_tensor_zeros = nf4_tensor.new_zeros(input_size)
        for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
            inner_tensor = getattr(nf4_tensor_zeros, attr)
            self.assertEqual(torch.count_nonzero(inner_tensor), 0)
        expected_size = input_size if not isinstance(input_size, int) else (input_size,)
        self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size))

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
        if isinstance(input_size, int):
            new_size = input_size + 1
        elif len(input_size) == 1:
            new_size = (input_size[0] + 1,)
        else:
            new_size = (input_size[0] + 1, input_size[1])
        nf4_tensor = to_nf4(torch.randn(input_size))
        with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"):
            nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        orig_attrs, _ = nf4_tensor.__tensor_flatten__()
        orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs])
        end_idx = input_size if isinstance(input_size, int) else input_size[0]
        sliced_tensor = nf4_tensor[:end_idx]
        self.assertEqual(nf4_tensor.size(), sliced_tensor.size())
        attrs, _ = sliced_tensor.__tensor_flatten__()
        for attr in attrs:
            orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
            sliced_tensor_inner = getattr(sliced_tensor, attr)
            self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage)
            self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr])

    def test_tensor_slice_1d_invalid(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"):
            nf4_tensor[..., ::2]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end "):
            nf4_tensor[:2]

    def test_tensor_slice_2d_invalid(self):
        nf4_tensor = to_nf4(torch.randn((512, 512)))
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"):
            nf4_tensor[:, :511]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"):
            nf4_tensor[:2]

    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        viewed_tensor = nf4_tensor.view(-1)
        self.assertEqual(viewed_tensor.dim(), 1)
        self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
        for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
            inner_tensor = getattr(viewed_tensor, attr)
            self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
            nf4_tensor.view(input_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset())
self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size())
self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride())
self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset())
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
inner_tensor_orig = getattr(nf4_tensor, attr)
inner_tensor_strided = getattr(nf4_tensor_strided, attr)
self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size())
self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride())
self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset())


@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
if len(input_size) == 1:
size = (input_size[0] - 1, )
else:
size = (input_size[0] - 1, input_size[1])
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"):
torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset())
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"):
torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1)

if len(input_size) == 2:
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"):
stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0])
torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset())

def test_pin_memory(self):
msaroufim (Member) commented on Apr 29, 2024:

I think you mentioned this briefly last week, but could you remind me how you figured out that these were the functions that needed to be tested? (I'm thinking ahead to a tutorial for someone who wants to upstream some new exotic dtype and get it working with FSDP.) That's probably a good candidate for what I mean by adding another smoke test, so we know for sure FSDP will work.

I ran the tests locally and they all passed, and fast! That gives me confidence that the NF4 tensor now supports many new ops, but it doesn't give me confidence that FSDP won't break in some way.

I was hoping we could have a smoke test of the sort fsdp(torch.nn.Sequential(LinearNF4(64, 64))) that would ensure nothing breaks and that FSDP doesn't silently drop the dtype, since that functionality wasn't tested for FSDP 1 and we had to rely on Twitter to get that signal.

weifengpy (Contributor, Author) commented on Apr 30, 2024:

Agreed that we need a smoke test on fsdp(model). I'm not sure how to set up a multi-GPU test in torchao, though. Are there .ci files to change? Is there an example in torchao? I'm happy to fill the actual logic into the template. As a reference, FSDP tests in PyTorch are done like pytorch/test/distributed/_composable/fsdp/test_fully_shard_training.py.

msaroufim (Member) commented on Apr 30, 2024:

Something identical should work on the machines we have in CI; every commit already runs on 4 A10Gs (linux.g5.12xlarge). There's no existing example, since this would be our first distributed test.

Let's just do this first thing when we meet tomorrow.
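For illustration, a minimal sketch of the kind of smoke test discussed in this thread — an assumption, not code from this PR. It assumes FSDP2's composable fully_shard API (torch.distributed._composable.fsdp) and torchao's to_nf4/linear_nf4; LinearNF4 is a hypothetical stand-in for an NF4-weight linear layer, and 512x512 weights are used so the default NF4 block sizes divide evenly:

```python
import torch
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API (assumed available)
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class LinearNF4(torch.nn.Linear):
    # Hypothetical stand-in: a Linear whose weight is quantized to NF4 at
    # construction and whose forward dispatches through linear_nf4.
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)
        self.weight = torch.nn.Parameter(to_nf4(self.weight), requires_grad=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_nf4(input=input, weight=self.weight)


def fsdp_nf4_smoke_test():
    # Assumes torchrun (or equivalent) has already initialized a CUDA process
    # group, as in the pytorch test file referenced above.
    model = torch.nn.Sequential(LinearNF4(512, 512)).cuda()
    fully_shard(model)  # shard the NF4 weight across ranks
    out = model(torch.randn(2, 512, device="cuda"))
    # The forward should run without crashing and without FSDP silently
    # dropping the NF4 dtype.
    assert out.shape == (2, 512)
```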

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertFalse(nf4_tensor.is_pinned())

        nf4_tensor = nf4_tensor.pin_memory()
        self.assertTrue(nf4_tensor.is_pinned())

        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
        self.assertFalse(nf4_tensor.is_pinned())

    def test_to_cuda(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)
        self.assertEqual(nf4_tensor.device.type, "cuda")

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda")
        self.assertEqual(nf4_tensor.device.type, "cuda")

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16)
        self.assertEqual(nf4_tensor.device.type, "cuda")
        self.assertEqual(nf4_tensor.dtype, torch.bfloat16)

    def test_to_cpu(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
        nf4_tensor.cpu()
msaroufim (Member) commented:

So is this just testing against crashes, or do we also expect nf4_tensor.device to be cpu?

weifengpy (Contributor, Author) commented:

Good catch. This is testing against crashes, but I will add an assertion that nf4_tensor.device.type == 'cpu'.
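The strengthened test the author describes might look roughly like this — a sketch, not code from this PR; the inner-tensor checks are an extra assumption beyond what was promised:

```python
def test_to_cpu(self):
    nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
    nf4_tensor = nf4_tensor.cpu()  # keep the result; .cpu() is not in-place
    self.assertEqual(nf4_tensor.device.type, 'cpu')
    # Assumed extra check: the inner tensors should move along with the wrapper.
    for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
        self.assertEqual(getattr(nf4_tensor, attr).device.type, 'cpu')
```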


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)

if __name__ == "__main__":
    run_tests()