Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/multi-gpu.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.

:::

::: {.callout-tip}

Using ZeRO Stage 3 with Single-GPU training

ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`

:::

## FSDP {#sec-fsdp}

### Basic FSDP Configuration {#sec-fsdp-config}
Expand Down
18 changes: 17 additions & 1 deletion src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import cached_property

import addict
import torch
import transformers
from transformers import PretrainedConfig, PreTrainedModel

Expand Down Expand Up @@ -165,10 +166,25 @@ def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing."""
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
from axolotl.monkeypatch.gradient_checkpointing import (
CheckpointFunctionWithCPUOffload,
hf_grad_checkpoint_offload_wrapper,
)

transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if (
self.cfg.gradient_checkpointing_kwargs
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
):
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_offload_wrapper
)
else:
transformers.modeling_utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
torch.utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
if self.cfg.gradient_checkpointing == "offload_disk":
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from packaging import version

from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
CheckpointFunctionWithCPUOffload,
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
Expand Down
166 changes: 166 additions & 0 deletions src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import inspect

import torch
from packaging import version
from torch.utils.checkpoint import (
_get_autocast_kwargs,
_get_device_module,
_infer_device_type,
check_backward_validity,
detach_variable,
get_device_states,
set_device_states,
)

# support different pytorch versions
has_device_type = "device_type" in inspect.signature(set_device_states).parameters

torch_version = version.parse(torch.__version__)

Expand Down Expand Up @@ -60,3 +76,153 @@ def backward(ctx, dY):
) + (
None,
) * len(ctx.args)


# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
"""
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
"""

@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.device_type = _infer_device_type(*args)
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
ctx.device_type
)
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_device_in_fwd = False
device_module = _get_device_module(ctx.device_type)
if getattr(device_module, "_initialized", False):
ctx.had_device_in_fwd = True
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
# x = None
for i, arg in enumerate(args):
if torch.is_tensor(arg):
# cpu-offload
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
# upstream could accept a list of arg indices to offload
if i == 0:
# print(f"{arg.shape=}")
ctx.x_device = arg.device
ctx.x_requires_grad = arg.requires_grad
t = arg.detach().cpu()
else:
t = arg
tensor_inputs.append(t)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)

ctx.save_for_backward(*tensor_inputs)

with torch.no_grad():
outputs = run_function(*args)

return outputs

@staticmethod
def backward(ctx, *args):
if (
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
):
raise RuntimeError(
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
" with .grad() or passing an `inputs` parameter to .backward()."
" To resolve this error, you can either set use_reentrant=False,"
" or call .backward() without passing the `inputs` argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors

# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
if i == 0:
t = (
tensors[i]
.to(ctx.x_device)
.detach()
.requires_grad_(ctx.x_requires_grad)
)
else:
t = tensors[i]
inputs[idx] = t

# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
rng_devices = []
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices
with torch.random.fork_rng(
devices=rng_devices,
enabled=ctx.preserve_rng_state,
device_type=ctx.device_type,
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
if has_device_type:
# newer pytorch (as early as 2.7)
set_device_states(
ctx.fwd_devices,
ctx.fwd_device_states,
device_type=ctx.device_type,
)
else:
# older pytorch (at least 2.4)
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))

device_autocast_ctx = (
torch.amp.autocast(
device_type=ctx.device_type, **ctx.device_autocast_kwargs
)
if torch.amp.is_autocast_available(ctx.device_type)
else contextlib.nullcontext()
)
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)

# run backward() with only tensor that requires grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True, this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)

return (None, None) + grads
10 changes: 7 additions & 3 deletions src/axolotl/monkeypatch/tiled_mlp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Monkeypatch for Tiled MLP implementation"""

import math
import os

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -29,15 +30,18 @@ def generic_mlp_forward(self_, hs):

mlp_forward = torch.compile(generic_mlp_forward)

is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Consider using existing distributed utility function.

The distributed detection logic is correct and aligns with the PR objective. However, there's an existing utility function is_distributed() in src/axolotl/utils/distributed.py that provides similar functionality.

Consider using the existing utility for consistency:

+from axolotl.utils.distributed import is_distributed
+
-        is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+        is_distributed_env = is_distributed()

Note that the existing function checks for initialized distributed state, which might be more robust than just checking the environment variable.


🏁 Script executed:

#!/bin/bash
# Description: Check if the existing is_distributed function is more comprehensive than WORLD_SIZE check

# Search for is_distributed function usage patterns
rg -A 10 "def is_distributed" src/axolotl/utils/distributed.py

# Check how WORLD_SIZE is used elsewhere in the codebase
rg -n "WORLD_SIZE" --type py

Length of output: 931


Use the centralized is_distributed() helper for consistency

The manual WORLD_SIZE check can be replaced with the existing utility in axolotl.utils.distributed, which also verifies that the distributed backend is initialized.

• File: src/axolotl/monkeypatch/tiled_mlp.py

+ from axolotl.utils.distributed import is_distributed
@@ src/axolotl/monkeypatch/tiled_mlp.py:33
-        is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
+        is_distributed_env = is_distributed()

This change ensures all distributed‐mode checks share the same robust logic.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
# at the top of src/axolotl/monkeypatch/tiled_mlp.py
import os
from axolotl.utils.distributed import is_distributed
# …later, around line 33…
# Replace the manual WORLD_SIZE check with the centralized helper
is_distributed_env = is_distributed()
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp.py at line 33, replace the manual check of
the WORLD_SIZE environment variable with the centralized is_distributed() helper
from axolotl.utils.distributed. Import the is_distributed function at the top of
the file and use it to determine distributed mode, ensuring consistent and
robust distributed environment detection across the codebase.


def tiled_mlp_forward(self, x):
input_shape = x.shape
seqlen = input_shape[-2]
hidden = input_shape[-1]
if cfg_num_shards is None:
num_shards = math.ceil(seqlen / hidden)
num_shards_tensor = torch.tensor(num_shards, device=x.device)
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
num_shards = num_shards_tensor.item()
if is_distributed:
num_shards_tensor = torch.tensor(num_shards, device=x.device)
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
num_shards = num_shards_tensor.item()
else:
num_shards = cfg_num_shards

Expand Down
10 changes: 8 additions & 2 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,14 @@ def pretrain_with_tps(cls, data):
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
if data.get("tiled_mlp", False) and not data.get("deepspeed"):
raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
capabilities = data.get("capabilities")
n_gpu = 0
if capabilities and capabilities.get("n_gpu", 0) >= 1:
n_gpu = capabilities.get("n_gpu", 0)
if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
raise ValueError(
"tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
)
return data


Expand Down
9 changes: 9 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,15 @@ def setup_deepspeed_env(cfg, stage=None):
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if int(os.environ.get("WORLD_SIZE", "1")) == 1:
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
import deepspeed.comm as dist

dist.init_distributed(
dist_backend="nccl", auto_mpi_discovery=False, dist_init_required=True
)
init_distributed_state()

# If we don't assign this, it doesn't actually get set in the accelerate weakref
Expand Down
Loading