
Unable to save checkpoints when using low-bit optimizers with FSDP1 or FSDP2 #1185

Open
nighting0le01 opened this issue Oct 28, 2024 · 3 comments

Comments

@nighting0le01

nighting0le01 commented Oct 28, 2024

This only occurs when using 8-bit Adam.

With FSDP1 I run into:

FSDP config
param_dtype: bf16
reduce_dtype: fp32

[rank7]: Traceback (most recent call last):
[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/parallel_expts/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/parallel_expts/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1364, in save_checkpoint
[rank7]:     checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 447, in dump_checkpoint
[rank7]:     optimizer_state = trainer.strategy.optimizer_state(optimizer)
[rank7]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/fsdp.py", line 539, in optimizer_state
[rank7]:     state_dict = FSDP.optim_state_dict(self.model, optimizer)
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1890, in optim_state_dict
[rank7]:     return FullyShardedDataParallel._optim_state_dict_impl(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1301, in _optim_state_dict_impl
[rank7]:     return _optim_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1976, in _optim_state_dict
[rank7]:     fsdp_osd_state = convert_fn(
[rank7]:                      ^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1799, in _convert_state_with_orig_params
[rank7]:     _gather_all_orig_param_state(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1693, in _gather_all_orig_param_state
[rank7]:     output_states = _allgather_orig_param_states(
[rank7]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1523, in _allgather_orig_param_states
[rank7]:     dtype, state_buffers = _convert_all_state_info(
[rank7]:                            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1382, in _convert_all_state_info
[rank7]:     assert dtype == info.dtype
[rank7]:            ^^^^^^^^^^^^^^^^^^^
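The assert fires inside FSDP's optimizer-state gathering, presumably because the 8-bit optimizer stores its state as uint8 tensors while FSDP expects all gathered state buffers for a parameter to share one dtype. A minimal sketch of the kind of setup that triggers this (untested; the model is a stand-in and the torchao import path may differ across versions):

```python
# Untested sketch: FSDP1 + torchao 8-bit AdamW, failing at optimizer
# state-dict gathering. Run under torchrun with one process per GPU.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torchao.prototype.low_bit_optim import AdamW8bit  # path may vary by version

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(
    nn.Linear(1024, 1024).cuda(),  # stand-in for the real model
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,  # matches the reported FSDP config
        reduce_dtype=torch.float32,
    ),
    use_orig_params=True,  # the traceback goes through the orig-params path
)
optim = AdamW8bit(model.parameters())

# one training step so the uint8 optimizer state tensors exist
model(torch.randn(8, 1024, device="cuda")).sum().backward()
optim.step()

# this is the call Lightning makes when saving a checkpoint, and where
# `assert dtype == info.dtype` is raised
osd = FSDP.optim_state_dict(model, optim)
```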
@gau-nernst
Collaborator

Is this a duplicate of #1189? Can I close this one so we can discuss over at the other issue?

@nighting0le01
Author

Hi @gau-nernst, this is not exactly a duplicate; it is for FSDP1. But we can move the discussion there if you prefer.

@gau-nernst
Collaborator

I don't think we actively support FSDP1. If you can create a minimal reproducible example, I can look into it. The errors seem different from those in #1189.
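One possible interim workaround (a sketch, assuming it is acceptable to drop optimizer state from checkpoints): with `save_weights_only=True`, Lightning's `dump_checkpoint` never reaches the failing `FSDP.optim_state_dict` call.

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# Skips optimizer/LR-scheduler state entirely, so FSDP.optim_state_dict is
# never called; resuming will reinitialize the optimizer from scratch.
checkpoint_cb = ModelCheckpoint(save_weights_only=True)
```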
