-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150
Changes from 30 commits
0a13e6a
9a56eaa
8180540
95b03e1
3ac9d81
38461b3
bc7a764
8b1d037
7ff6855
d9bcf71
923bef2
e15d244
d4beb8f
f41cb3d
952fbdd
950d9fd
d4eae0b
9444f2c
b2c3c02
9be2de3
2981393
3ced998
3f1e19a
761416a
c656f1e
c56d7e2
d656b93
5c4fe2b
613bf67
3933bfa
9e6b4ec
857b8db
8e3de02
912998b
8926ee1
6f834ce
a8a5aaa
699079d
c8b047c
925602c
e36ab6c
a007027
c352552
c83fdad
574fecd
f27760b
b4f51b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import logging | ||
import unittest | ||
from packaging import version | ||
import math | ||
|
||
import torch | ||
from torch import nn | ||
|
@@ -15,6 +16,7 @@ | |
import io | ||
from collections import OrderedDict | ||
import torchao | ||
from typing import Tuple, Union | ||
|
||
|
||
bnb_available = False | ||
|
@@ -236,7 +238,172 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype): | |
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4) | ||
|
||
|
||
class TestFSDPOps(TestCase): | ||
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) | ||
def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]): | ||
num_chunks = 2 | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
chunks = list(torch.chunk(nf4_tensor, num_chunks)) | ||
self.assertEqual(len(chunks), num_chunks) | ||
if isinstance(input_size, int): | ||
expected_size0 = input_size // num_chunks | ||
else: | ||
expected_size0 = input_size[0] // num_chunks | ||
for chunk in chunks: | ||
self.assertEqual(chunk.size(0), expected_size0) | ||
|
||
@parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)]) | ||
def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]): | ||
num_chunks = 2 | ||
with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
torch.chunk(nf4_tensor, num_chunks) | ||
|
||
@parametrize("input_size", [(512, 512, 512)]) | ||
def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]): | ||
num_chunks = 2 | ||
with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
torch.chunk(nf4_tensor, num_chunks) | ||
|
||
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) | ||
def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
nf4_tensor_zeros = nf4_tensor.new_zeros(input_size) | ||
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: | ||
inner_tensor = getattr(nf4_tensor_zeros, attr) | ||
self.assertEqual(torch.count_nonzero(inner_tensor), 0) | ||
expected_size = input_size if not isinstance(input_size, int) else (input_size, ) | ||
self.assertEqual(nf4_tensor_zeros.size(), torch.Size(expected_size)) | ||
|
||
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) | ||
def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]): | ||
if isinstance(input_size, int): | ||
new_size = input_size + 1 | ||
elif len(input_size) == 1: | ||
new_size = (input_size[0] + 1, ) | ||
else: | ||
new_size = (input_size[0] + 1, input_size[1]) | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"): | ||
nf4_tensor_zeros = nf4_tensor.new_zeros(new_size) | ||
|
||
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) | ||
def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
orig_attrs, _ = nf4_tensor.__tensor_flatten__() | ||
orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs]) | ||
end_idx = input_size if isinstance(input_size, int) else input_size[0] | ||
sliced_tensor = nf4_tensor[:end_idx] | ||
self.assertEqual(nf4_tensor.size(), sliced_tensor.size()) | ||
attrs, _ = sliced_tensor.__tensor_flatten__() | ||
for attr in attrs: | ||
orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr() | ||
sliced_tensor_inner = getattr(sliced_tensor, attr) | ||
self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage) | ||
self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr]) | ||
|
||
def test_tensor_slice_1d_invalid(self): | ||
nf4_tensor = to_nf4(torch.randn(512 * 512)) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"): | ||
nf4_tensor[..., ::2] | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): | ||
nf4_tensor[1:] | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end "): | ||
nf4_tensor[:2] | ||
|
||
def test_tensor_slice_2d_invalid(self): | ||
nf4_tensor = to_nf4(torch.randn((512, 512))) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"): | ||
nf4_tensor[:, :511] | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): | ||
nf4_tensor[1:] | ||
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"): | ||
nf4_tensor[:2] | ||
|
||
@parametrize("input_size", [(512 * 512,), (512, 512)]) | ||
def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
viewed_tensor = nf4_tensor.view(-1) | ||
self.assertEqual(viewed_tensor.dim(), 1) | ||
self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) | ||
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: | ||
inner_tensor = getattr(viewed_tensor, attr) | ||
self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) | ||
|
||
@parametrize("input_size", [(512 * 512,), (512, 512)]) | ||
def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): | ||
nf4_tensor.view(input_size) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): | ||
nf4_tensor.view(input_size) | ||
|
||
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) | ||
def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset()) | ||
self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size()) | ||
self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride()) | ||
self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset()) | ||
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: | ||
inner_tensor_orig = getattr(nf4_tensor, attr) | ||
inner_tensor_strided = getattr(nf4_tensor_strided, attr) | ||
self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size()) | ||
self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride()) | ||
self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset()) | ||
|
||
|
||
@parametrize("input_size", [(512 * 512,), (512, 512)]) | ||
def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]): | ||
nf4_tensor = to_nf4(torch.randn(input_size)) | ||
if len(input_size) == 1: | ||
size = (input_size[0] - 1, ) | ||
else: | ||
size = (input_size[0] - 1, input_size[1]) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"): | ||
torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset()) | ||
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"): | ||
torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1) | ||
|
||
if len(input_size) == 2: | ||
with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"): | ||
stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0]) | ||
torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset()) | ||
|
||
def test_pin_memory(self): | ||
nf4_tensor = to_nf4(torch.randn(512 * 512)) | ||
self.assertFalse(nf4_tensor.is_pinned()) | ||
|
||
nf4_tensor = nf4_tensor.pin_memory() | ||
self.assertTrue(nf4_tensor.is_pinned()) | ||
|
||
nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) | ||
self.assertFalse(nf4_tensor.is_pinned()) | ||
|
||
def test_to_cuda(self): | ||
nf4_tensor = to_nf4(torch.randn(512 * 512)) | ||
self.assertEqual(nf4_tensor.device.type, "cpu") | ||
nf4_tensor = nf4_tensor.to("cuda", non_blocking=True) | ||
self.assertEqual(nf4_tensor.device.type, "cuda") | ||
|
||
nf4_tensor = to_nf4(torch.randn(512 * 512)) | ||
self.assertEqual(nf4_tensor.device.type, "cpu") | ||
nf4_tensor = nf4_tensor.to("cuda") | ||
self.assertEqual(nf4_tensor.device.type, "cuda") | ||
|
||
nf4_tensor = to_nf4(torch.randn(512 * 512)) | ||
self.assertEqual(nf4_tensor.device.type, "cpu") | ||
nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16) | ||
self.assertEqual(nf4_tensor.device.type, "cuda") | ||
self.assertEqual(nf4_tensor.dtype, torch.bfloat16) | ||
|
||
def test_to_cpu(self): | ||
nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda')) | ||
nf4_tensor.cpu() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so this is just testing against crashes or do also expect the nf4_tensor.device to be cpu? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch. this is testing against crashes but i will add assertion on |
||
|
||
instantiate_parametrized_tests(TestNF4Linear) | ||
instantiate_parametrized_tests(TestFSDPOps) | ||
|
||
if __name__ == "__main__": | ||
run_tests() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you mentioned this briefly last week but could you remind me how you figured out these would be the functions that needed to be tested. (I'm thinking ahead with a tutorial for someone who wants to upstream some new exotic dtpye and get it working with fsdp). That's probably a good candidate for what I mean by we should add another smoke test so we know for sure FSDP will work
So I ran the tests locally and they all worked and fast! So this gives me confidence the nf4 tensor now supports many new ops but it doesnt give me confidence that fsdp won't break in some way
I was hoping we could have a smoke test of the sort
fsdp(torch.nn.Sequential(LinearNF4(64,64)))
that would ensure nothing breaks and that fsdp doesn't silently drop the dtype since that functionality wasn't tested for fsdp 1 and we had to rely on twitter to get that signalThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree that we need a smoke test on
fsdp(model)
. Not sure how to setup a multi-gpu test in torchao though. Is there some .ci files to change? Is there some example in torchAO? I am happy to fill in the actual logic into the template. As a reference, FSDP tests in pytorch are done like thispytorch/test/distributed/_composable/fsdp/test_fully_shard_training.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something identical should work the machines we have in CI, every commit is already running on 4 A10Gs linux.g5.12xlarge. No existing example since this is our first distributed test
Let's just do this, first thing we meet tomorrow