Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
322a24a
use new upstream branches for nd-parallelism
winglian Jul 23, 2025
b0633b1
handle tp load
winglian Jul 23, 2025
277f974
make sure to return data for validation
winglian Jul 23, 2025
36a06be
update tp validation
winglian Jul 23, 2025
fe5805f
handle none checks
winglian Jul 23, 2025
38fed08
fix for accelerator state getting reset and missing schema
winglian Jul 23, 2025
b563aa0
use latest transformers on main with fix
winglian Jul 23, 2025
f315e1e
workaround for fsdp2 optimizer save failures
winglian Jul 23, 2025
df06050
use updated mesh builder
winglian Jul 23, 2025
533c2ae
no need to patch data loader anymore
winglian Jul 23, 2025
0cfbd46
register ring attn using device mesh instead of static size
winglian Jul 24, 2025
5453f41
fix cp dim
winglian Jul 24, 2025
cfa33ca
use updated branch w fix
winglian Jul 24, 2025
a50014d
workaround for upstream waiting on pr
winglian Jul 24, 2025
ca16d7e
updating token count
salmanmohammadi Jul 24, 2025
75b3b49
force previous behavior for loss for CP
winglian Jul 24, 2025
8ec048d
updating for upstream
salmanmohammadi Jul 24, 2025
34d5670
use transformers main for now
winglian Jul 24, 2025
428bc04
chore: lint
winglian Jul 24, 2025
d00515d
don't bother with device mesh for configurations that don't require it
winglian Jul 24, 2025
d50e1e4
check if parallelism config is set before setting use_configured_state
winglian Jul 24, 2025
36b034e
upstream patches
winglian Jul 25, 2025
bd3cbe8
more implementation fixes
winglian Jul 25, 2025
82e5991
Fix parallelism config setup
winglian Jul 25, 2025
c40cdfd
fixing error handling, SP
Jul 25, 2025
d9e7dbb
adding DP replicate, more validation, gpu mem logging
Jul 25, 2025
2597653
lint and fix dangling no-op statement
winglian Jul 25, 2025
d7e5b02
better handling of when to use parallel config, basically, not ddp
winglian Jul 25, 2025
b7e7581
fix the checks :facepalm:
winglian Jul 25, 2025
270967e
comments
Jul 25, 2025
b6a03be
nits
Jul 25, 2025
d17b445
linting
salmanmohammadi Jul 25, 2025
8ede662
handle value error gracefully
winglian Jul 25, 2025
60c9620
improve handling for tests
winglian Jul 25, 2025
08eadb6
handle process count in ci and fix parallel setting and add tests
winglian Jul 25, 2025
6e28602
remove print and add another test case
winglian Jul 25, 2025
9e0d86b
update to latest transformers and only install latest vllm on 2.7.1
winglian Jul 25, 2025
6c308aa
use intermediate loader for ParallelismConfig while we wait for release
winglian Jul 26, 2025
44425a4
add missing class
winglian Jul 26, 2025
ec4ed1e
improve parallelism config check
winglian Jul 26, 2025
3271815
use current releasE
winglian Jul 26, 2025
4948f38
fixes for broken tp
winglian Jul 26, 2025
36429ab
no-pack, no pad CP seems to timeout
winglian Jul 26, 2025
e5c14d8
cast to bool and debug out for now
winglian Jul 26, 2025
9d0f382
fix vllm version and fix upstream tp issues
winglian Jul 26, 2025
b7f9027
better handling to not handle for ddp
winglian Jul 26, 2025
f72ef0a
Fix vllm in requires
winglian Jul 27, 2025
0248d93
fix logic from merge
winglian Jul 27, 2025
b1ab8cc
fix patches
winglian Jul 27, 2025
d147a3f
lint
winglian Jul 28, 2025
cc933ae
use integration branch for next transformers release w fixeS
winglian Jul 28, 2025
fce55f4
remove accidentally commited yaml in root
winglian Jul 28, 2025
3e0fb45
remove accidental file add
winglian Jul 29, 2025
2975c6c
don't use pre-release workaround workflow
winglian Jul 29, 2025
ad75e15
use updated accelerate
winglian Jul 29, 2025
93b37fc
more fixes
winglian Jul 30, 2025
8d80ef9
fix saving fsdp2
winglian Jul 30, 2025
b759168
fsdp2 with sharded state dicts should work now
winglian Jul 30, 2025
51b2b7b
cleanup from PR feedback
winglian Jul 30, 2025
668e15c
guard on paralle config
winglian Jul 30, 2025
7e90cce
more cleanup
winglian Jul 30, 2025
eb451f8
deprecated field and migrate [skip e2e]
winglian Jul 30, 2025
23369a9
fix fdsp checkpoint save across trainers
winglian Jul 31, 2025
44d293c
removing valuerror
salmanmohammadi Jul 31, 2025
f8df5bf
skip tp test
winglian Jul 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cicd/multigpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -e

# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
Expand Down
5 changes: 4 additions & 1 deletion cicd/single_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec

sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"

# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
12 changes: 6 additions & 6 deletions docs/sequence_parallelism.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
context_parallel_size: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:

- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
Expand Down Expand Up @@ -66,7 +66,7 @@ sequence_len: 8192

...

sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
Expand All @@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
2 changes: 1 addition & 1 deletion examples/alst/llama3-8b-deepspeed-alst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true

tiled_mlp: true
sequence_parallel_degree: 8
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ packaging==23.2

huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.0
transformers==4.54.1
tokenizers>=0.21.1
accelerate==1.9.0
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0
deepspeed>=0.17.0
trl==0.20.0
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/cli/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
context_parallel_size=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
Expand Down
26 changes: 25 additions & 1 deletion src/axolotl/core/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
from typing import Any

import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames

from axolotl.integrations.base import PluginManager
Expand Down Expand Up @@ -434,8 +436,30 @@ def _configure_torch_compile(self, training_args_kwargs: dict):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode

def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)

def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/core/builders/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_trainer_cls(self, trainer_kwargs: dict):

if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))

Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
Expand All @@ -50,6 +51,7 @@ class AxolotlTrainer(
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
Expand Down
13 changes: 11 additions & 2 deletions src/axolotl/core/trainers/dpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from torch import nn
from trl import DPOTrainer

from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
Expand All @@ -17,7 +21,12 @@


class AxolotlDPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DPOTrainer,
DistributedParallelMixin,
):
"""Extend the base DPOTrainer for axolotl helpers."""

Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/core/trainers/grpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,14 @@ def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]:
grpo_args_kwargs["log_completions"] = trl.log_completions
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print

if cfg.context_parallel_size > 1:
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size

if trl.importance_sampling_level is not None:
grpo_args_kwargs["importance_sampling_level"] = (
trl.importance_sampling_level
)

if cfg.sequence_parallel_degree > 1:
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree

if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/core/trainers/grpo/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""

sequence_parallel_degree: int | None = None
context_parallel_size: int | None = None
12 changes: 6 additions & 6 deletions src/axolotl/core/trainers/grpo/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
- Data is properly distributed across SP groups.

In the table below, the values represent dataset indices. Each SP group has
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
`context_parallel_size = 2` GPUs working together on the same data. There are 2
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.

Sequence Parallel Groups
Expand All @@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
rank: Rank of current process.
batch_size: Number of samples per batch.
repeat_count: How many times to repeat the full sampling process.
sequence_parallel_degree: Number of ranks in a sequence parallel group.
context_parallel_size: Number of ranks in a sequence parallel group.
shuffle: Whether to shuffle the dataset.
seed: Random seed for shuffling.
drop_last: Whether to drop the last incomplete batch.
Expand All @@ -59,7 +59,7 @@ def __init__(
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
context_parallel_size: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
Expand All @@ -77,9 +77,9 @@ def __init__(
self.rank = rank

# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
self.context_parallel_size = context_parallel_size
self.num_sp_groups = world_size // context_parallel_size
self.sp_group_id = rank // context_parallel_size

# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
Expand Down
39 changes: 24 additions & 15 deletions src/axolotl/core/trainers/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
from trl.trainer.utils import pad

from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
RngLoaderMixin,
SchedulerMixin,
)
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
from axolotl.monkeypatch.ring_attn import get_ring_attn_group

Expand All @@ -53,7 +57,12 @@


class AxolotlGRPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
GRPOTrainer,
):
"""Extend the base GRPOTrainer for axolotl helpers"""

Expand Down Expand Up @@ -100,7 +109,7 @@ def __init__(

# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
num_sp_groups = num_processes // self.args.context_parallel_size

# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
Expand Down Expand Up @@ -130,7 +139,7 @@ def __init__(

if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
Expand Down Expand Up @@ -167,9 +176,9 @@ def _get_train_sampler(self) -> Sampler:
rank=self.rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
// self.args.context_parallel_size,
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
sequence_parallel_degree=self.args.sequence_parallel_degree,
context_parallel_size=self.args.context_parallel_size,
shuffle=True,
seed=self.args.seed,
drop_last=True,
Expand Down Expand Up @@ -235,7 +244,7 @@ def _prepare_dataloader(
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
return dataloader

# Otherwise prepare with accelerator
Expand Down Expand Up @@ -308,18 +317,18 @@ def _generate_and_score_completions(
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate sequence parallel group information
world_size = self.accelerator.num_processes
sequence_parallel_degree = self.args.sequence_parallel_degree
num_sp_groups = world_size // sequence_parallel_degree
context_parallel_size = self.args.context_parallel_size
num_sp_groups = world_size // context_parallel_size

# Since processes in the same SP group have the same prompts, we need to ensure
# we only take one copy of each prompt from each SP group
ordered_set_of_prompts = []
for sp_group_id in range(num_sp_groups):
# Get the first process from each SP group (typically the group leader)
group_leader_rank = sp_group_id * sequence_parallel_degree
group_leader_rank = sp_group_id * context_parallel_size

# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
Expand All @@ -335,7 +344,7 @@ def _generate_and_score_completions(
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[
:: self.num_generations * self.args.sequence_parallel_degree
:: self.num_generations * self.args.context_parallel_size
]

with profiling_context(self, "vLLM.generate"):
Expand All @@ -352,14 +361,14 @@ def _generate_and_score_completions(
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
len(all_prompts_text) // self.args.context_parallel_size
)

# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)

# Determine the appropriate slice based on sequence parallelism
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size

Expand Down Expand Up @@ -583,7 +592,7 @@ def _generate_and_score_completions(
advantages = advantages / (std_grouped_rewards + 1e-4)

# Slice to keep only the local part of the data
if self.args.sequence_parallel_degree > 1:
if self.args.context_parallel_size > 1:
# Calculate SP group ID (which group of ranks this rank belongs to)
sp_group_id = self.accelerator.process_index // self.local_world_size

Expand Down
1 change: 1 addition & 0 deletions src/axolotl/core/trainers/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from axolotl.core.trainers.base import AxolotlTrainer


# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""

Expand Down
Loading
Loading