Merged

47 commits
7b5d8d3
Second version of degub/deterinistic configs.
githubsgi Oct 7, 2025
48f0e6b
Review relaeted updates.
githubsgi Oct 10, 2025
11250fa
Review 2 related changes.
githubsgi Oct 10, 2025
68ec57d
[DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804)
wwwjn Oct 8, 2025
01df7d0
Disable FlexAttention max-autotune when deterministic is used (#1808)
fegin Oct 8, 2025
1e5fa06
Fix num of layers for deepseek-v3 (#1845)
wwwjn Oct 9, 2025
72620cc
[VLM] Add token-imbalance loss (#1803)
lkhphuc Oct 9, 2025
91bf469
refactor TrainSpec to remove the name field (#1850)
tianyu-l Oct 10, 2025
18185d7
Refactor attention and make attention mask an argument to the model (…
fegin Oct 10, 2025
c48a430
minor refactor over EP (#1854)
tianyu-l Oct 12, 2025
66ca021
Graduate qwen3 from experiment to core (#1860)
wwwjn Oct 13, 2025
1af8113
Review related updates.
githubsgi Oct 10, 2025
4e71f05
Rebasing and adding MATH attention kernel.
githubsgi Oct 13, 2025
6846bc3
Indent issue fix.
githubsgi Oct 13, 2025
1f86c98
Removing ipex.
githubsgi Oct 13, 2025
f8f98c5
Review updates.
githubsgi Oct 14, 2025
e1d2117
Fixing linter error.
githubsgi Oct 14, 2025
83e0f8d
graduate llama4 to core (#1865)
tianyu-l Oct 14, 2025
58d5fb8
consolidate experiments/deepseek_v3 (#1869)
tianyu-l Oct 14, 2025
a645315
add auto_eager_graph_pass (#1813)
ruisizhang123 Oct 14, 2025
abf961f
Second version of degub/deterinistic configs.
githubsgi Oct 7, 2025
4ee1b41
Review relaeted updates.
githubsgi Oct 10, 2025
59bf2a7
Review 2 related changes.
githubsgi Oct 15, 2025
3962362
[DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804)
wwwjn Oct 8, 2025
21e38c0
Disable FlexAttention max-autotune when deterministic is used (#1808)
fegin Oct 8, 2025
6d350f3
[VLM] Add token-imbalance loss (#1803)
lkhphuc Oct 9, 2025
55a58c4
refactor TrainSpec to remove the name field (#1850)
tianyu-l Oct 10, 2025
872d88c
Refactor attention and make attention mask an argument to the model (…
fegin Oct 10, 2025
3ba6274
add script to train with ft (#1812)
tushar00jain Oct 10, 2025
c2d4822
Indent issue fix.
githubsgi Oct 13, 2025
63a2e93
Post rebase changes.
githubsgi Oct 15, 2025
9ff013c
Second version of degub/deterinistic configs.
githubsgi Oct 7, 2025
0b712a6
Review relaeted updates.
githubsgi Oct 10, 2025
d209acc
Review 2 related changes.
githubsgi Oct 10, 2025
93f6513
add script to train with ft (#1812)
tushar00jain Oct 10, 2025
0ebc9cf
minor refactor over EP (#1854)
tianyu-l Oct 12, 2025
aaabc2c
[vlm] Add light-weight CI for experimental models (#1848)
wwwjn Oct 12, 2025
00cad5c
add owners and CI status for experiments (#1859)
tianyu-l Oct 13, 2025
ac71851
move PP API to model agnostic file (#1868)
tianyu-l Oct 14, 2025
c976c97
[refactor] graduate custom_config_module and unify args/config naming…
tianyu-l Oct 14, 2025
0e99bea
Rebase misses.
githubsgi Oct 15, 2025
7ef61b7
Rebase mistakes.
githubsgi Oct 15, 2025
646ce75
Lint'er error fixes.
githubsgi Oct 16, 2025
b0dd8bd
Review updates.
githubsgi Oct 22, 2025
3243f58
Review unpdates.
githubsgi Oct 24, 2025
35a11f1
ufmt fixes.
githubsgi Oct 29, 2025
4f5f606
Update scripts/generate/test_generate.py
githubsgi Oct 29, 2025
17 changes: 15 additions & 2 deletions docs/debugging.md
@@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP,
Set consistent random seeds across all parallelism dimensions:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42
```

**Seed behavior with parallelism:**
@@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic
```

**What it does:**
@@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
- Sets deterministic workspace configuration for CuBLAS operations
- **Note:** This will significantly reduce training performance but ensures exact reproducibility

Use `--debug.deterministic_warn_only` to only warn about (rather than error out on) kernels that lack a deterministic implementation.
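
These flags can also be set in the job TOML file. A minimal sketch, assuming the `[debug]` table maps onto the `Debug` config section added in this PR:

```toml
[debug]
seed = 42                       # base RNG seed used for training
deterministic = true            # use deterministic algorithms wherever possible (slower)
deterministic_warn_only = true  # warn on non-deterministic ops instead of erroring out
```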

### Activation Checkpointing Debugging

The following debug configs are available for AC; a TOML sketch follows the list.

`preserve_rng_state` - set to true if deterministic output compared to non-checkpointed passes is required. Stashes and restores the RNG state during each checkpoint, which may be slower.

`determinism_check` - a string specifying the determinism check applied to recomputed tensors (`"default"` or `"none"`).

`debug` - capture AC debug information. Will be slower.

See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
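
A minimal TOML sketch of these options, assuming the table name follows the `activation_checkpoint` field on `JobConfig`:

```toml
[activation_checkpoint]
mode = "full"                  # or "selective"
preserve_rng_state = true      # stash/restore RNG state around each checkpointed region
determinism_check = "default"  # compare recomputed tensors against the saved ones
debug = true                   # capture AC debug information (slower)
```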

### Seed-Checkpoint-based Reproducibility

6 changes: 4 additions & 2 deletions scripts/generate/test_generate.py
@@ -25,7 +25,7 @@
RowwiseParallel,
)
from torchtitan.components.metrics import build_device_memory_monitor
from torchtitan.config import ConfigManager
from torchtitan.config import ConfigManager, Debug as DebugConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.train_spec import get_train_spec
from torchtitan.tools import utils
@@ -133,7 +133,9 @@ def test_generate(
# sequences would require https://github.com/pytorch/torchtitan/pull/686
apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"])

dist_utils.set_determinism(world_mesh, device, seed, deterministic)
debug_config = DebugConfig()
debug_config.seed = seed
debug_config.deterministic = deterministic
dist_utils.set_determinism(world_mesh, device, debug_config)

# materialize model
model.to_empty(device=device_type)
2 changes: 2 additions & 0 deletions torchtitan/config/__init__.py
@@ -16,6 +16,7 @@
ActivationCheckpoint,
Checkpoint,
Comm,
Debug,
FaultTolerance,
Job,
JobConfig,
@@ -49,4 +50,5 @@
"Profiling",
"Training",
"Validation",
"Debug",
]
45 changes: 36 additions & 9 deletions torchtitan/config/job_config.py
@@ -263,15 +263,6 @@ class Training:
many temporary files.
"""

seed: int | None = None
"""Choose the base RNG seed used for training"""

deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

debug_moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""


@dataclass
class Parallelism:
@@ -639,6 +630,26 @@ class ActivationCheckpoint:
https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
"""

preserve_rng_state: bool = False
"""
If deterministic output compared to non-checkpointed passes is required, set
to true. Results in stashing and restoring the RNG state during each checkpoint,
may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html
for details.
"""

determinism_check: str = "default"
"""
A string specifying the determinism function. See
https://docs.pytorch.org/docs/stable/checkpoint.html for details.
"""

debug: bool = False
"""
Capture ac debug information. Will be slower. See
https://docs.pytorch.org/docs/stable/checkpoint.html for details.
"""


@dataclass
class Compile:
@@ -887,6 +898,21 @@ def __post_init__(self):
), "validation steps must be positive or -1"


@dataclass
class Debug:
seed: int | None = None
"""Choose the base RNG seed used for training"""

deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

deterministic_warn_only: bool = False
"""Only warns about ops without deterministic implementations rather than erroring out """

moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""


@dataclass
class JobConfig:
"""
@@ -912,6 +938,7 @@ class JobConfig:
fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
experimental: Experimental = field(default_factory=Experimental)
validation: Validation = field(default_factory=Validation)
debug: Debug = field(default_factory=Debug)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
19 changes: 14 additions & 5 deletions torchtitan/distributed/activation_checkpoint.py
@@ -38,7 +38,11 @@ def _apply_layer_sac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
ac_freq = int(ac_config.selective_ac_option)
if not ac_freq or _layer_sac_count % ac_freq == 0:
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
module,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
early_stop=ac_config.early_stop,
debug=ac_config.debug,
)
else:
return module
@@ -125,8 +129,10 @@ def selective_checkpointing_context_fn():
return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=False,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
early_stop=ac_config.early_stop,
debug=ac_config.debug,
)


@@ -141,7 +147,11 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module:
nn.Module: The module with full activation checkpointing applied.
"""
return ptd_checkpoint_wrapper(
module, preserve_rng_state=False, early_stop=ac_config.early_stop
module,
preserve_rng_state=ac_config.preserve_rng_state,
determinism_check=ac_config.determinism_check,
early_stop=ac_config.early_stop,
debug=ac_config.debug,
)


@@ -157,7 +167,7 @@ def _apply_op_sac_to_transformer_block_with_flex(

Args:
module (nn.Module): The transformer block to apply SAC to.
ac_config (ACConfig): The activation checkpointing config.
ac_config (ACConfig): The Activation Checkpoint config.
base_fqn (str, optional): The base fqn of the module. Defaults to None.
model_compile_enabled (bool): Whether model compilation is enabled.
Defaults to False.
@@ -298,7 +308,6 @@ def apply_ac(
Returns:
None
"""

if ac_config.mode == "memory_budget":
assert model_compile_enabled, "Memory budget mode requires model to be compiled"
if ac_config.visualize_memory_budget_pareto:
11 changes: 7 additions & 4 deletions torchtitan/distributed/utils.py
@@ -17,7 +17,7 @@
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor

from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
from torchtitan.config import Comm as CommConfig, Debug as DebugConfig, TORCH_DTYPE_MAP
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type
@@ -83,8 +83,7 @@ def dist_mean(
def set_determinism(
world_mesh: DeviceMesh | None,
device: torch.device,
seed: int | None = None,
deterministic: bool = False,
debug_config: DebugConfig,
distinct_seed_mesh_dim: str = "pp",
) -> None:
"""
@@ -97,9 +96,12 @@

Set Determinism flags for increased reproducibility with loss of performance.
"""
if deterministic:
if debug_config.deterministic:
logger.info("Deterministic algorithm enabled (expect perf degradation).")
torch.use_deterministic_algorithms(True)
torch.use_deterministic_algorithms(
True, warn_only=debug_config.deterministic_warn_only
)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# env var for deterministic CuBLAS
@@ -114,6 +116,7 @@

FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)

seed = debug_config.seed
if not world_mesh:
if seed is not None:
torch.manual_seed(seed)
3 changes: 1 addition & 2 deletions torchtitan/experiments/forge/engine.py
@@ -104,8 +104,7 @@ def __init__(self, job_config: ForgeJobConfig):
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
)
self.train_spec = get_train_spec(job_config.model.name)

2 changes: 2 additions & 0 deletions torchtitan/experiments/forge/job_config.py
@@ -12,6 +12,7 @@
Checkpoint,
Comm,
Compile,
Debug,
Job,
LRScheduler,
MemoryEstimation,
@@ -45,6 +46,7 @@ class ForgeJobConfig:
# fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
# experimental: Experimental = field(default_factory=Experimental)
# validation: Validation = field(default_factory=Validation)
debug: Debug = field(default_factory=Debug)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
1 change: 1 addition & 0 deletions torchtitan/models/attention.py
@@ -99,6 +99,7 @@ def __init__(self) -> None:
SDPBackend.CUDNN_ATTENTION,
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]

def forward(
2 changes: 1 addition & 1 deletion torchtitan/models/deepseek_v3/model/args.py
@@ -107,7 +107,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
job_config.debug.moe_force_load_balance
)

def get_nparams_and_flops(
3 changes: 1 addition & 2 deletions torchtitan/models/flux/train.py
@@ -34,8 +34,7 @@ def __init__(self, job_config: JobConfig):
dist_utils.set_determinism(
self.parallel_dims.world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
distinct_seed_mesh_dim="dp_shard",
)

2 changes: 1 addition & 1 deletion torchtitan/models/llama4/model/args.py
@@ -82,7 +82,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
)

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
job_config.debug.moe_force_load_balance
)

def get_nparams_and_flops(
2 changes: 1 addition & 1 deletion torchtitan/models/qwen3/model/args.py
@@ -56,7 +56,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
self.max_seq_len = seq_len

self.moe_args._debug_force_load_balance = (
job_config.training.debug_moe_force_load_balance
job_config.debug.moe_force_load_balance
)

def get_nparams_and_flops(
4 changes: 2 additions & 2 deletions torchtitan/train.py
@@ -11,6 +11,7 @@
from typing import Any, Generator, Iterable, Optional

import torch

from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.protocols.train_spec as train_spec_module
@@ -118,8 +119,7 @@ def __init__(self, job_config: JobConfig):
dist_utils.set_determinism(
world_mesh,
self.device,
job_config.training.seed,
job_config.training.deterministic,
job_config.debug,
)
self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
