add FSDP QLoRA test and revert failing PR #403
Conversation
```diff
@@ -11,10 +11,6 @@
 from torch import Tensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch._prims_common import make_contiguous_strides_for
-from torchao.dtypes.utils import (
```
#360 consolidated `_implements` and `_ATEN_OP_OR_TORCH_FN_TABLE`, but it breaks torchtune. Reverting for now to unblock torchtune quickly.
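For context, a rough sketch of what the consolidated registration helper from #360 looks like; the table layout and decorator signature here are assumptions based on this thread, not the exact torchao code:

```python
from collections import defaultdict

# one shared table mapping tensor subclass -> {aten op / torch fn -> handler}
_ATEN_OP_OR_TORCH_FN_TABLE = defaultdict(dict)

def _implements(cls, aten_ops_or_torch_fns):
    """Register the decorated function as cls's handler for each listed op."""
    def decorator(func):
        for op in aten_ops_or_torch_fns:
            _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = func
        return func
    return decorator
```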
Do you know how exactly this breaks torchtune? Is it a versioning thing between saved models and this new model?
The error is `TypeError: nf4_detach() missing 1 required positional argument: 'args'`, so there is something incompatible around `_ATEN_OP_OR_TORCH_FN_TABLE[func](*args, **kwargs)`. The error shows up when people start training in torchtune for the first time.
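A minimal reconstruction of the mismatch, assuming the NF4 handlers still use the old `(aten_op, args, kwargs)` calling convention while the consolidated table unpacks arguments before calling (the handler body and op key below are illustrative):

```python
# old-style NF4 handler: expects the packed args tuple as a parameter
def nf4_detach(aten_op, args, kwargs=None):
    return args[0]  # simplified: return the detached tensor

_ATEN_OP_OR_TORCH_FN_TABLE = {"aten.detach.default": nf4_detach}

def dispatch(func, args=(), kwargs=None):
    # consolidated convention: unpack args before calling the handler
    return _ATEN_OP_OR_TORCH_FN_TABLE[func](*args, **(kwargs or {}))

# dispatch("aten.detach.default", ("nf4_tensor",)) calls nf4_detach("nf4_tensor"),
# binding only `aten_op` and raising:
# TypeError: nf4_detach() missing 1 required positional argument: 'args'
```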
@jerryzh168 any thoughts on why this is happening? Otherwise, are you okay with undoing your changes?
@drisspg to add a bit more to what @weifengpy already said, the full stack trace is here.
```python
def test_qlora_fsdp2(self):
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy

    self.run_subtests(
```
An e2e multi-GPU FSDP + QLoRA test should be able to catch regressions like this in the future. A plausible completion of the truncated snippet above is sketched below.
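This is a sketch only, assuming `run_subtests` sweeps FSDP2 offload policies; the subtest config key and the `_test_qlora_fsdp2` helper are hypothetical names, not necessarily what the PR adds:

```python
def test_qlora_fsdp2(self):
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy

    # exercise both GPU-resident and CPU-offloaded parameter states end to end
    self.run_subtests(
        {"offload_policy": [OffloadPolicy(), CPUOffloadPolicy()]},
        self._test_qlora_fsdp2,  # hypothetical per-config test body
    )
```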
Commits:
* add FSDP QLoRA test and revert failing PR
* check pytorch version and cuda for ci
* revert linter
Fix error when running torchtune QLoRA + FSDP2 (#380): `TypeError: nf4_detach() missing 1 required positional argument: 'args'`, triggered by the torchtune command.

Test plan:
* `pytest -s test/dtypes/test_nf4.py -k test_qlora`: e2e FSDP2 + QLoRA test
* `pytest -s test/dtypes/test_nf4.py -k test_tensor_copy`: torchtune implemented `NF4.clone`; upstream it to TorchAO. This is needed by the unit test's `copy.deepcopy(model)`.
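A minimal sketch of what the `copy.deepcopy(model)` path exercises, assuming torchao's `to_nf4` helper (the module shape and names are illustrative): deepcopy clones/detaches every parameter, which is exactly the NF4 dispatch path that regressed.

```python
import copy

import torch.nn as nn
from torchao.dtypes.nf4tensor import to_nf4

linear = nn.Linear(64, 64, bias=False)
# swap the fp32 weight for an NF4-quantized tensor subclass
linear.weight = nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)

# deepcopy routes through NF4Tensor's clone/detach handlers; before the
# revert this raised the nf4_detach TypeError
copied = copy.deepcopy(linear)
assert isinstance(copied.weight, type(linear.weight))
```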