Add FSDP v2 swap memory support + QLoRA compatibility fixes#3167
Conversation
| def noop_pin_memory(self, device=None): | ||
| return self | ||
|
|
||
| torch.Tensor._original_pin_memory = torch.Tensor.pin_memory |
There was a problem hiding this comment.
Thanks so much for this PR @gholmes829!
I'm curious if you needed both CPUOffloadPolicy.pin_memory = False and to patch torch.Tensor._original_pin_memory? Did you try with just the former?
There was a problem hiding this comment.
Of course, thanks for the response @salmanmohammadi!
I started with just disabling the CPUOffloadPolicy.pin_memory = False as that seemed to be the culprit at first. But shortly after it finished the "Loading & Quantizing Model Shards" phase, it crashed with the exception I share below.
Interestingly, the following htop screenshot shows it was already using a handful of swap (after it already began loading the weights onto VRAM) so my guess is it moved on to some phase where it wanted to pin memory in preparation of training or maybe optimizing cuda graph or something.
Another thing to note on this, is that just now I tried the reverse: commenting out CPUOffloadPolicy.pin_memory = False but keeping the tensor pin_memory patch. This configuration seems to work, which I suppose could make sense since FSDP is part of pytorch and likely relies on the tensor pin_memory.
I'll wait to hear your thoughts on this, but I'll plan on updating this PR to remove the redundant CPUOffloadPolicy patch given this new discovery.
Here is the crash trace when I keep the CPUOffloadPolicy.pin_memory = False patch but I comment out the tensor pin memory patching:
[2025-09-18 14:45:30,216] [INFO] [axolotl.cli.checks] Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used.
[2025-09-18 14:45:30,733] [INFO] [axolotl.utils.data.shared] Loading prepared dataset from disk at /workspace/datasets/preprocessed/wizard-32b/3823c563c30437bb88742aafe70941a1...
[W918 14:45:34.116269282 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W918 14:45:34.116566120 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W918 14:45:34.117420429 ProcessGroupGloo.cpp:727] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W918 14:45:34.121388642 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W918 14:45:34.121629680 socket.cpp:200] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W918 14:45:34.122259041 ProcessGroupGloo.cpp:727] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[2025-09-18 14:45:35,084] [INFO] [axolotl.utils.samplers.multipack] gather_len_batches: [51, 51]
[2025-09-18 14:45:35,153] [INFO] [axolotl.utils.trainer] sample_packing_eff_est across ranks: [0.8877140879631042, 0.8877140879631042]
[2025-09-18 14:45:35,155] [INFO] [axolotl.utils.data.sft] Maximum number of steps set at 60
[2025-09-18 14:45:35,782] [INFO] [axolotl.loaders.patch_manager] Applying multipack dataloader patch for sample packing...
[2025-09-18 14:45:36,102] [INFO] [axolotl.integrations.liger.plugin] Applying LIGER to qwen3 with kwargs: {'rope': True, 'cross_entropy': None, 'fused_linear_cross_entropy': True, 'rms_norm': True, 'swiglu': True}
Loading & Quantizing Model Shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [08:17<00:00, 29.28s/it]
[2025-09-18 14:53:54,871] [INFO] [axolotl.loaders.adapter] found linear modules: ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 536,870,912 || all params: 33,298,994,176 || trainable%: 1.6123
df: /root/.triton/autotune: No such file or directory
[2025-09-18 14:54:15,397] [INFO] [axolotl.train] Pre-saving adapter config to /workspace/outputs/axolotl/wizard-32b-test...
[2025-09-18 14:54:15,400] [INFO] [axolotl.train] Pre-saving tokenizer to /workspace/outputs/axolotl/wizard-32b-test...
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
[2025-09-18 14:54:15,780] [INFO] [axolotl.train] Pre-saving model config to /workspace/outputs/axolotl/wizard-32b-test...
[2025-09-18 14:54:15,795] [INFO] [axolotl.train] Starting trainer...
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
[2025-09-18 14:54:24,469] [INFO] [axolotl.utils.samplers.multipack] gather_len_batches: [50, 50]
0%| | 0/60 [00:00<?, ?it/s][rank1]: Traceback (most recent call last):
[rank1]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank1]: File "<frozen runpy>", line 88, in _run_code
[rank1]: File "/workspace/axolotl/src/axolotl/cli/train.py", line 121, in <module>
[rank1]: fire.Fire(do_cli)
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/fire/core.py", line 135, in Fire
[rank1]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/fire/core.py", line 468, in _Fire
[rank1]: component, remaining_args = _CallAndUpdateTrace(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank1]: component = fn(*varargs, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/axolotl/src/axolotl/cli/train.py", line 88, in do_cli
[rank1]: return do_train(parsed_cfg, parsed_cli_args)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/axolotl/src/axolotl/cli/train.py", line 45, in do_train
[rank1]: model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/axolotl/src/axolotl/train.py", line 583, in train
[rank1]: execute_training(cfg, trainer, resume_from_checkpoint)
[rank1]: File "/workspace/axolotl/src/axolotl/train.py", line 204, in execute_training
[rank1]: trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer.py", line 2328, in train
[rank1]: return inner_training_loop(
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/transformers/trainer.py", line 2671, in _inner_training_loop
[rank1]: with context():
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/contextlib.py", line 137, in __enter__
[rank1]: return next(self.gen)
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/accelerate/accelerator.py", line 1172, in no_sync
[rank1]: with context():
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/contextlib.py", line 137, in __enter__
[rank1]: return next(self.gen)
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1047, in no_sync
[rank1]: _lazy_init(self, self)
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 140, in _lazy_init
[rank1]: _share_state_and_init_handle_attrs(state, root_module)
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 226, in _share_state_and_init_handle_attrs
[rank1]: handle.init_flat_param_attributes()
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 1226, in init_flat_param_attributes
[rank1]: ).pin_memory()
[rank1]: ^^^^^^^^^^^^
[rank1]: RuntimeError: CUDA error: out of memory
[rank1]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank1]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank1]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
W0918 14:57:16.341000 200 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 243 closing signal SIGTERM
E0918 14:57:22.062000 200 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 1 (pid: 244) of binary: /root/miniconda3/envs/py3.11/bin/python3
Traceback (most recent call last):
File "/root/miniconda3/envs/py3.11/bin/accelerate", line 7, in <module>
sys.exit(main())
^^^^^^
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
args.func(args)
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1226, in launch_command
multi_gpu_launcher(args)
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
distrib_run.run(args)
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/run.py", line 883, in run
elastic_launch(
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
axolotl.cli.train FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-09-18_14:57:16
host : 3fedb28fe59e
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 244)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
There was a problem hiding this comment.
Which model are you using?
There was a problem hiding this comment.
@salmanmohammadi the official full precision Qwen3 32B, configured with adapter: qlora and load_in_4bit: true
There was a problem hiding this comment.
I think in an ideal world we would be able to use CPUOffloadPolicy(pin_memory=False) by adding a config field for cpu_offload_pin_memory in fsdp_config, and then to just enable cpu_ram_efficient_loading which would eliminate the GPU OOM you reported in the trace above.
Let me see if I can make this happen. We've currently disabled the combination of load_in_4bit and cpu_ram_efficient_loading in src/axolotl/utils/schemas/validation.py, line 806, but if you're interested in hacking on this you could disable that check and see how far you get - I think some things in bitsandbytes might have changed since I added that check. Make sure you're using the latest bitsandbytes version.
You'll need cpu_ram_efficient_loading: true in your config and your CPUOffloadPolicy(pin_memory=True) fix for this route.
There was a problem hiding this comment.
@salmanmohammadi I have good news!
I was able to remove the validation and get everything working with FSDP v2, as per your recommendations. This is great since I was stuck with v1 before.
Here is current state of implemented changes:
- Added
cpu_offload_pin_memorytofsdp_config(replaces env var I originally proposed) - FSDP v2 with swap fallback now works via
FullyShardedDataParallelPluginwithpin_memory=False(no direct pytorch patching) - FSDP v1 with swap fallback now works via pytorch tensor patch (seems unavoidable due to differences in how pytorch handles v1)
Now this PR enables swap usage and also unlocks some new combinations that are also very helpful for resource constrained / consumer grade hardware.
I've pushed all changed and updated title/description.
There was a problem hiding this comment.
Really impressive work - thank you for sticking with this. I have a couple of suggestions to streamline the implementation if you're willing to go a little further:
- Firstly, I think we should drop the tensor patching changes if they're only required for FSDP1 as we are currently in the process of phasing out FSDP1 in favour of FSDP2 and fixing more edge cases like this helps us get closer to deprecating FSDP1. I'm happy if this feature is only supported with FSDP2.
- It looks like you've enabled this by manually constructing the
FullyShardedDataParallelPluginso we can pass aCPUOffloadPolicyconfigured with your new config field. I actually think we can simplify this without needing to modify a bunch of code and manually constructing the plugin on our side. To achieve this, I would:- Add another line in the
src.axolotl.utils.trainer.setup_fsdp_envsfunction to setos.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = cfg.fsdp_config.cpu_offload_pin_memory. Then, - In
src.axolotl.monkeypatch.accelerate.fsdp2.fsdp2_prepare_model- near the top of this function you can do something like
Some additional validation logic may be needed to ensureif isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy) and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "true") == "false: fsdp2_plugin.cpu_offload.pin_memory = False
offload_paramsis set ifcpu_offload_pin_memoryisfalse. - Add another line in the
What do you think? I'm also happy to pick this up and help take this over the line if this doesn't sound fun.
There was a problem hiding this comment.
@salmanmohammadi that all sounds good to me, I should be able to make the changes later today!
There was a problem hiding this comment.
@salmanmohammadi I've successfully added and tested your suggestions -- definitely much cleaner without v1 support and with using src.axolotl.monkeypatch.accelerate.fsdp2.fsdp2_prepare_model.
Let me know how this looks and if there are any other changes we need to make.
There was a problem hiding this comment.
@salmanmohammadi I ran the pre-commit and formatting is all fixed up now
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughIntroduces a configurable option to disable FSDP CPU offload pin memory via env/config, updates validation rules and tests for FSDP v2 behavior, adjusts FSDP model preparation to respect the setting, and adds documentation and example config comments about enabling swap by unpinning memory. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (6)
examples/llama-2/qlora-fsdp.yml (1)
69-69: Align key name with docs and deprecation policy (dropfsdp_prefix).Use the non‑prefixed field to avoid deprecation warnings and match the docs. Also, consider noting the
offload_params: truerequirement in the comment for clarity.Apply this diff:
- # fsdp_cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient + # cpu_offload_pin_memory: false # uncomment to enable swap memory usage when RAM is insufficient (requires offload_params: true)src/axolotl/utils/trainer.py (1)
598-600: Propagate only for FSDP v2 to reduce ambiguity.This env var is only consumed by the FSDP v2 monkeypatch. Gate it on
fsdp_version == 2to avoid confusion in mixed setups.Apply this diff:
- if cfg.fsdp_config.cpu_offload_pin_memory is not None: - os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str(cfg.fsdp_config.cpu_offload_pin_memory).lower() + if str(cfg.fsdp_version) == "2" and cfg.fsdp_config.cpu_offload_pin_memory is not None: + os.environ["FSDP_CPU_OFFLOAD_PIN_MEMORY"] = str( + cfg.fsdp_config.cpu_offload_pin_memory + ).lower()docs/fsdp_qlora.qmd (2)
26-31: Document FSDP v2 requirement and keep key name consistent.Explicitly call out
fsdp_version: 2and usecpu_offload_pin_memory(no prefix) to match validation and the env wiring.Apply this diff:
-## Enabling Swap for FSDP +## Enabling Swap for FSDP (v2) -If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config. +If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `fsdp_version: 2` and configuring `offload_params: true` with `cpu_offload_pin_memory: false` in your FSDP config. -This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs some performance impact and actually having to use swap space incurs even more. However, it may allow training larger models than otherwise would have been possible due to OOM errors. +This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling pinning incurs overhead, and using swap adds more, but it can enable larger models that would otherwise OOM.
2-2: Fix typo in title ("FSDP").Apply this diff:
-title: "FDSP + QLoRA" +title: "FSDP + QLoRA"src/axolotl/monkeypatch/accelerate/fsdp2.py (1)
281-285: Make env parsing robust (accept common falsy strings).Trim and lowercase the env value and treat “0/no/false” equivalently.
Apply this diff:
- if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "") == "false": + if offload_to_cpu and os.environ.get("FSDP_CPU_OFFLOAD_PIN_MEMORY", "").strip().lower() in {"0", "no", "false"}: fsdp2_plugin.cpu_offload.pin_memory = Falsetests/utils/schemas/validation/test_fsdp.py (1)
96-107: Add tests for deprecatedfsdp_‑prefixed keys to prevent validation gaps.Cover both the allowed and error paths when using the prefixed fields, matching real configs/examples.
Apply this diff:
@@ def test_fsdp2_cpu_offload_pin_memory_w_offload_params(self, min_base_cfg): cfg = min_base_cfg | DictDefault( fsdp_config={ "cpu_offload_pin_memory": False, "offload_params": True, }, fsdp_version=2, ) validated_cfg = validate_config(cfg) assert validated_cfg.fsdp_config.cpu_offload_pin_memory is False assert validated_cfg.fsdp_config.offload_params is True + + def test_fsdp2_cpu_offload_pin_memory_prefixed_keys_ok(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "fsdp_cpu_offload_pin_memory": False, + "fsdp_offload_params": True, + }, + fsdp_version=2, + ) + validated = validate_config(cfg) + assert validated.fsdp_config.cpu_offload_pin_memory is False + assert validated.fsdp_config.offload_params is True + + def test_fsdp2_cpu_offload_pin_memory_prefixed_requires_offload(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + fsdp_config={ + "fsdp_cpu_offload_pin_memory": False, + # missing offload flag + }, + fsdp_version=2, + ) + with pytest.raises( + ValueError, + match="disabling cpu_offload_pin_memory requires enabling offload_params", + ): + validate_config(cfg)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
docs/fsdp_qlora.qmd(1 hunks)examples/llama-2/qlora-fsdp.yml(1 hunks)src/axolotl/monkeypatch/accelerate/fsdp2.py(2 hunks)src/axolotl/utils/schemas/validation.py(1 hunks)src/axolotl/utils/trainer.py(1 hunks)tests/utils/schemas/validation/test_fsdp.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/schemas/validation/test_fsdp.py (2)
src/axolotl/utils/config/__init__.py (1)
validate_config(259-305)src/axolotl/utils/dict.py (1)
DictDefault(6-38)
🪛 Ruff (0.13.1)
src/axolotl/utils/schemas/validation.py
822-822: Multiple statements on one line (colon)
(E701)
826-828: Avoid specifying long messages outside the exception class
(TRY003)
830-832: Avoid specifying long messages outside the exception class
(TRY003)
9484d64 to
e98efe8
Compare
|
Amazing work @gholmes829. Could you please run the pre-commit hooks? Otherwise LGTM. > pip install pre-commit
> pre-commit install
> pre-commit run --all-files |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
039e7d9 to
aece705
Compare
|
Rebased on main |
Description
This PR enables training larger models than would normally be possible on resource constrained systems. This is accomplished by two main changes:
qlora_sharded_model_loadingwith FSDP v2. Previously, I had to fallback to FSDP v1 which is clearly undesirable.Key changes:
cpu_offload_pin_memory: falseis setCPUOffloadPolicy(pin_memory=False)cpu_ram_efficient_loadingcombinationoffload_params: truewhen disabling pin_memorycpu_ram_efficient_loadingcombination now worksNow viable:
Motivation and Context
I have been messing around with several different popular fine tuning frameworks. One thing that drew me to Axolotl in particular, was its mention of special FSDP + QLoRa support, touted as potentially enabling training 70B model on 2x3090s / 48GB VRAM. I was sold and wasted no time trying it out.
...only to realize I hadn't stopped to read some of the fine print. The training feat in question was indeed conducted by only 2 consumer grade GPUs -- with the help of 128GB RAM. I only had 32 GB RAM to spare, and I eventually ran out of experimental optimization flags to hopelessly try out.
But then I realized, why should I be stopped from training a larger model? I had plenty of disk space and wasn't in a great hurry to fine tune my models, so I configured my system to provide 128GB swap memory to the container housing Axolotl.
...only to realize I was still getting OOM errors, this time from the CPU. Diving deeper, I saw this was specifically due to limitations of pinned memory, something that pytorch (both FSDP and more generic PT internals) make fond use of. Unlike DeepSpeed, FSDP and PT did not expose this option.
And thus this PR is born!
I thought there must have been a good reason for this to not already be implemented, so I ran it expecting to uncover some more sinister error or discover fine tuning time would be measured in decades.
Instead, after my hours of toil, I found myself with a nice QLoRa adapter ~30 minutes later for my dense Qwen3 32B. This one wasn't super high rank, but nonetheless succeeded to influence the models style. High epoch and higher rank LoRa finetunes can take several hours for me but that is much better than not being able to do it at all.
How has this been tested?
tests\utils\schemas\validation\test_fsdp.pyScreenshots (if appropriate)
Likely N/A -- but I could add some charts of memory usage if anyone wants.
Types of changes
I'd classify this as hybrid of feature / bug fix.
Summary by CodeRabbit
New Features
Documentation
Tests