
[FSDP2][NF4Tensor][2/n] implement torch.chunk and other ops #150

Merged: 47 commits, merged May 1, 2024
Changes from 1 commit

Commits (47):
0a13e6a
proof of concept for FSDP2 + NF4Tensor
weifengpy Apr 4, 2024
9a56eaa
Merge branch 'main' into main
cpuhrsch Apr 4, 2024
8180540
fsdp extention for tensor subclass
weifengpy Apr 11, 2024
95b03e1
support fp32
weifengpy Apr 15, 2024
3ac9d81
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 16, 2024
38461b3
UNIT TEST FOR STATE DICT
weifengpy Apr 16, 2024
bc7a764
implement to
weifengpy Apr 17, 2024
8b1d037
remove torch.override from torch function
weifengpy Apr 17, 2024
7ff6855
use dtype in compile unit test
weifengpy Apr 17, 2024
d9bcf71
add dtype in all unit test
weifengpy Apr 17, 2024
923bef2
keep original dtype
weifengpy Apr 17, 2024
e15d244
fix linter
weifengpy Apr 17, 2024
d4beb8f
use torch testing @parametrize
weifengpy Apr 17, 2024
f41cb3d
remove unused import
weifengpy Apr 17, 2024
952fbdd
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 17, 2024
950d9fd
sm8 for fp16
weifengpy Apr 17, 2024
d4eae0b
remove sm check for fp16
weifengpy Apr 18, 2024
9444f2c
skip 2.2.2 and below for tracing tensor subclass
weifengpy Apr 18, 2024
b2c3c02
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 18, 2024
9be2de3
include kwargs
weifengpy Apr 19, 2024
2981393
raise unimplemented
weifengpy Apr 19, 2024
3ced998
Merge branch 'main' into main
weifengpy Apr 19, 2024
3f1e19a
Merge branch 'pytorch-labs:main' into main
weifengpy Apr 19, 2024
761416a
fsdp2 ops
weifengpy Apr 19, 2024
c656f1e
better diff layout
weifengpy Apr 19, 2024
c56d7e2
set pg size in metadata
weifengpy Apr 19, 2024
d656b93
remove irrelevant changes
weifengpy Apr 19, 2024
5c4fe2b
add unit test
weifengpy Apr 20, 2024
613bf67
Merge branch 'main' into main
msaroufim Apr 26, 2024
3933bfa
torch.chunk and cpu offloading ops
weifengpy Apr 27, 2024
9e6b4ec
remove strict same metadata check
weifengpy Apr 27, 2024
857b8db
skip tests that needs cuda
weifengpy Apr 27, 2024
8e3de02
use /( in regex match
weifengpy Apr 27, 2024
912998b
fix regex
weifengpy Apr 28, 2024
8926ee1
skip tests if no cuda
weifengpy Apr 28, 2024
6f834ce
skip unit test if no cuda
weifengpy Apr 28, 2024
a8a5aaa
Merge branch 'pytorch:main' into main
weifengpy Apr 28, 2024
699079d
assert cpu device
weifengpy Apr 30, 2024
c8b047c
name args[0] as nf4tensor
weifengpy Apr 30, 2024
925602c
utils for apply to inner tensors and constructor
weifengpy Apr 30, 2024
e36ab6c
use original copy_
weifengpy Apr 30, 2024
a007027
decorator for args check
weifengpy May 1, 2024
c352552
Merge branch 'main' into main
cpuhrsch May 1, 2024
c83fdad
INNER_TENSOR_NAMES_FOR_SHARDING and unify assert in split and new_zeros
weifengpy May 1, 2024
574fecd
Merge branch 'pytorch:main' into main
weifengpy May 1, 2024
f27760b
indicate private constant with _
weifengpy May 1, 2024
b4f51b9
Merge branch 'main' into fsdp2ops
weifengpy May 1, 2024
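Commit 3933bfa above adds `torch.chunk` support, which FSDP2 relies on to shard a tensor along dim 0. For reference, a small pure-Python sketch of `torch.chunk`'s size rule (the helper name `chunk_sizes` is mine, not from the PR): each chunk has size `ceil(n / chunks)`, so the last chunk may be smaller and trailing empty chunks are omitted.

```python
import math

def chunk_sizes(n: int, chunks: int) -> list[int]:
    """Sizes torch.chunk would produce splitting a length-n dim into `chunks` pieces."""
    step = math.ceil(n / chunks)  # every chunk but possibly the last has this size
    sizes = []
    remaining = n
    while remaining > 0:
        take = min(step, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

# e.g. splitting 10 rows into 3 chunks yields sizes [4, 4, 2]
```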
use original copy_
weifengpy committed Apr 30, 2024
commit e36ab6c2061807861dcdb90269254ed3603609bc
20 changes: 10 additions & 10 deletions torchao/dtypes/nf4tensor.py
```diff
@@ -288,18 +288,18 @@ def mm_default(func, *args, **kwargs):
         aten.copy_.default,
     ]
 )
-def nf4_copy_(aten_op, args, kwargs=None):
-    assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args"
-    original: NF4Tensor = args[0]
-    copy_in: torch.Tensor = args[1]
+def copy_(func, *args, **kwargs):
+    assert len(args[0]) == 2 and len(kwargs) == 0, "only support aten.copy_.default with 2 args"
+    original: NF4Tensor = args[0][0]
+    copy_in: torch.Tensor = args[0][1]

     # Base Case
     if same_metadata(original, copy_in):
-        attrs, _ = original.__tensor_flatten__()
-        for attr in attrs:
-            inner_tensor_orig = getattr(original, attr)
-            inner_tensor_copy_in = getattr(copy_in, attr)
-            aten_op(inner_tensor_orig, inner_tensor_copy_in, **kwargs)
-        return original
+        original_tensors = original.__tensor_flatten__()[0]
+        for tensor_name in original_tensors:
+            getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name))
+        return

     # Convert Non NF4Tensor into NF4 for copy in
     if not isinstance(copy_in, NF4Tensor):
```
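In the fast path above, when source and destination share metadata, `copy_` walks the subclass's inner tensors by name via `__tensor_flatten__` and copies each in place. A minimal, torch-free sketch of that pattern (the `Buf` and `FakeNF4` classes and their field names are illustrative stand-ins, not the real NF4Tensor internals):

```python
class Buf:
    """Stand-in for an inner torch.Tensor supporting in-place copy_."""
    def __init__(self, data):
        self.data = list(data)

    def copy_(self, other):
        # In-place copy, like torch.Tensor.copy_
        self.data[:] = other.data


class FakeNF4:
    """Stand-in for NF4Tensor: inner buffers are reachable by attribute name."""
    _inner_names = ["quantized_data", "quantization_factor"]

    def __init__(self, quantized_data, quantization_factor):
        self.quantized_data = Buf(quantized_data)
        self.quantization_factor = Buf(quantization_factor)

    def __tensor_flatten__(self):
        # Subclass protocol: (list of inner tensor attribute names, metadata)
        return self._inner_names, {}


def copy_(func, *args, **kwargs):
    # Dispatch packs the op's operands as one tuple in args[0]:
    # args[0][0] is the destination, args[0][1] the source.
    assert len(args[0]) == 2 and len(kwargs) == 0
    original, copy_in = args[0]
    for name in original.__tensor_flatten__()[0]:
        getattr(original, name).copy_(getattr(copy_in, name))


dst = FakeNF4([0, 0, 0], [0.0])
src = FakeNF4([7, 8, 9], [1.5])
copy_(None, (dst, src))  # func slot unused in this sketch
```

Iterating by flattened names rather than hard-coding attributes keeps the op in sync with whatever inner tensors the subclass declares.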