
Add refactored recipe files for pretrain configs of LLMs#2067

Merged
yaoyu-33 merged 28 commits into main from
athitten/recipe_refactor
Feb 2, 2026

Conversation

@athitten
Contributor

@athitten athitten commented Jan 26, 2026

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • Refactor
    • Simplified pretraining configuration APIs across all model recipes (DeepSeek, Gemma, GLM, GPT, Llama, Qwen, and others). Configuration functions now follow a consistent, parameterless pattern returning fully populated configurations instead of requiring custom arguments, reducing complexity and improving consistency across model variants.


@copy-pr-bot

copy-pr-bot bot commented Jan 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@athitten athitten force-pushed the athitten/recipe_refactor branch from 8449d49 to 6734170 Compare January 26, 2026 18:31
@athitten
Contributor Author

/ok to test cfc6df8

@athitten
Contributor Author

/ok to test d7cfa85

@athitten
Contributor Author

/ok to test fb8938a

@athitten athitten marked this pull request as ready for review January 28, 2026 06:57
@athitten
Contributor Author

/ok to test ab9c99c

@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

Walkthrough

This PR introduces a centralized pretraining configuration helper (_pretrain_common()) and refactors 15+ recipe files to use it instead of accepting user-supplied kwargs. Each model's pretrain config function now builds configurations declaratively by calling _pretrain_common() and explicitly setting fields, replacing the previous pattern of multi-parameter common helpers.

Changes

Cohort / File(s) Summary
New shared utility
src/megatron/bridge/recipes/common.py
Introduces _pretrain_common() helper that returns a base ConfigContainer pre-populated with sensible defaults (optimizers, DDP config, dataset, logging, tokenizer, checkpoint, RNG, mixed precision) for language-model pretraining. Model and tokenizer_model must be set by caller.
DeepSeek variants
src/megatron/bridge/recipes/deepseek/deepseek_v2.py, deepseek_v3.py
Removed kwargs-driven API and _deepseek_common() helpers. Added deepseek_v2_lite_pretrain_config() and deepseek_v2_pretrain_config() (V2), plus deepseek_v3_pretrain_config_32nodes() variant (V3). All now use _pretrain_common() with explicit model, MoE, training, and optimizer field assignments.
Gemma variants
src/megatron/bridge/recipes/gemma/gemma2.py, gemma3.py
Removed user kwargs and _gemma2_common()/_gemma3_common() helpers. Updated gemma2_2b/9b/27b_pretrain_config() and gemma3_1b_pretrain_config() to parameterless functions using _pretrain_common() with explicit field setup.
GLM4.5 variants
src/megatron/bridge/recipes/glm/glm45.py
Removed kwargs-driven _glm45_common(). Updated glm45_355b_pretrain_config() and glm45_air_106b_pretrain_config() to use _pretrain_common() with explicit MoE, parallelism, and training configs.
GPT-3 and GPT-OSS
src/megatron/bridge/recipes/gpt/gpt3_175b.py, gpt_oss/gpt_oss.py
Removed parameterized pretrain_config() and model builders. Added gpt3_175b_pretrain_config() and updated gpt_oss_20b/120b_pretrain_config() to use _pretrain_common() with explicit defaults (TP=4, PP=8, VP=6 for GPT-3).
Kimi and Moonlight
src/megatron/bridge/recipes/kimi/kimi_k2.py, moonlight/moonlight_16b.py
Removed kwargs-driven APIs and _kimi_k2_common()/_moonlight_common() helpers. Added _get_kimi_k2_pipeline_layout() and _get_moonlight_pipeline_layout() helper functions. Updated kimi_k2_pretrain_config() and moonlight_16b_pretrain_config() to use _pretrain_common().
Llama variants
src/megatron/bridge/recipes/llama/llama2.py, llama3.py
Removed Llama2CommonKwargs/Llama3CommonKwargs TypedDicts and _llama2_common()/_llama3_common() helpers. Updated 12 Llama3/Llama3.1 pretrain functions to parameterless signatures using _pretrain_common() with HF-based model initialization and NullTokenizer defaults.
Nemotron variants
src/megatron/bridge/recipes/nemotronh/nemotron_nano_v2.py, nemotronh.py
Removed kwargs-driven _nemotron_common() helpers. Updated nemotron_nano_9b/12b_v2_pretrain_config() and nemotronh_4b/8b/47b/56b_pretrain_config() to use _pretrain_common() with explicit model, tokenizer, and DDP settings.
OLMoE and other MoE variants
src/megatron/bridge/recipes/olmoe/olmoe_7b.py
Removed kwargs-driven _olmoe_common() and _model_config(). Added _get_olmoe_pipeline_layout() helper. Updated olmoe_7b_pretrain_config() to use _pretrain_common() with explicit MoE, layout, and precision configs.
Qwen variants
src/megatron/bridge/recipes/qwen/qwen2.py, qwen3.py, qwen3_moe.py, qwen3_next.py
Removed Qwen*CommonKwargs TypedDicts and _qwen*_common() helpers. Updated 15+ pretrain functions (qwen2_500m/1p5b/7b/72b, qwen25_*, qwen3_*, qwen3_30b_a3b, qwen3_235b_a22b, qwen3_next_80b_a3b) to parameterless signatures using _pretrain_common() with AutoBridge-based model initialization.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

The refactoring is systematic and pattern-consistent across 15+ files, but requires careful verification that default configurations are sensible, no critical settings are lost in the transition from kwargs to explicit field assignments, and the _pretrain_common() baseline provides appropriate defaults for diverse model architectures (various parameter counts, parallelism strategies, MoE configurations, and precision settings). The heterogeneity of affected files and model-specific configurations demands thorough cross-file consistency checks.

Suggested reviewers

  • cuichenx
  • yaoyu-33
  • skyw
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Test Results For Major Changes ⚠️ Warning PR represents major refactoring of pretraining configurations across 20+ files with 6,000+ lines modified, lacking comprehensive testing documentation, convergence verification, and resolution of identified configuration errors. Add validation tests confirming identical model architectures, conduct training convergence tests per architecture family, provide configuration validation evidence, and resolve identified review issues before approval.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'Add refactored recipe files for pretrain configs of LLMs' accurately summarizes the main change—refactoring recipe files for pretraining configurations across multiple large language models.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.



@athitten
Contributor Author

/ok to test 1885a66

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/megatron/bridge/recipes/nemotronh/nemotron_nano_v2.py (1)

19-23: Duplicate import of NemotronNanoModelProvider9Bv2.

The same class is imported twice:

  • Line 19: from megatron.bridge.models import NemotronNanoModelProvider9Bv2
  • Lines 20-22: from megatron.bridge.models.nemotronh import (NemotronNanoModelProvider9Bv2, ...)

Remove the redundant import on line 19.

Proposed fix
-from megatron.bridge.models import NemotronNanoModelProvider9Bv2
 from megatron.bridge.models.nemotronh import (
     NemotronNanoModelProvider9Bv2,
     NemotronNanoModelProvider12Bv2,
 )
src/megatron/bridge/recipes/nemotronh/nemotronh.py (1)

520-532: Incorrect TypedDict annotation for recommended_kwargs.

The recommended_kwargs dictionary is annotated as NemotronHFinetuneKwargs, which defines keys like tensor_model_parallel_size, pipeline_model_parallel_size, and sequence_parallel (inherited from NemotronHCommonKwargs). However, the dictionary uses keys tensor_parallelism, pipeline_parallelism, and sequence_parallelism instead. While these keys match the function parameter names for _nemotronh_finetune_common, they violate the TypedDict contract and will cause type checking errors in mypy or pyright. Either rename the dictionary keys to match the TypedDict field names, or adjust the type annotation.
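A minimal sketch of the key-matching rule (field names taken from the comment above; values are illustrative): a dict annotated with a TypedDict must use exactly the declared field names, so a key like "tensor_parallelism" would be flagged by mypy/pyright even though Python accepts the dict at runtime.

```python
from typing import TypedDict


class NemotronHCommonKwargs(TypedDict, total=False):
    # Declared field names; a literal using other keys (e.g.
    # "tensor_parallelism") would fail static type checking.
    tensor_model_parallel_size: int
    pipeline_model_parallel_size: int
    sequence_parallel: bool


# Keys must match the declared field names exactly:
recommended_kwargs: NemotronHCommonKwargs = {
    "tensor_model_parallel_size": 8,
    "pipeline_model_parallel_size": 1,
    "sequence_parallel": True,
}
```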

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/recipes/gpt_oss/gpt_oss.py`:
- Around line 152-157: The config currently leaves cfg.model.pipeline_dtype as
None while cfg.model.pipeline_model_parallel_size is >1; update both occurrences
of cfg.model.pipeline_dtype in the GPT-OSS recipe to be torch.bfloat16 when
cfg.model.pipeline_model_parallel_size > 1 (i.e., set cfg.model.pipeline_dtype =
torch.bfloat16), matching other PP>1 configs and ensuring this change is applied
to both places where pipeline_dtype is currently None.

In `@src/megatron/bridge/recipes/llama/llama3.py`:
- Around line 1013-1017: The CommOverlapConfig for cfg.comm_overlap currently
references userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 while the model
parallelism (TP) for this recipe is TP=8; update the comm overlap config to the
matching TP=8 variant (e.g. replace
userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 with the TP=8 equivalent such as
userbuffers_bf16_h100_h8192_tp8_mbs1_seqlen8192) so
CommOverlapConfig(tp_comm_overlap=True, tp_comm_overlap_cfg=...) matches the
actual TP setting.
- Around line 659-663: Replace the runtime-checking assert with an explicit
exception: instead of using assert mixed_precision_recipe in [...] in llama3.py,
raise a ValueError when mixed_precision_recipe is not in the allowed list
(include the same descriptive message, e.g. f"Invalid low precision recipe:
{mixed_precision_recipe}. This recipe has not been tested yet."). Update the
validation near the mixed_precision_recipe variable so it always runs (not
skipped by -O) by using if mixed_precision_recipe not in [...] : raise
ValueError(...).
🧹 Nitpick comments (13)
src/megatron/bridge/recipes/qwen/qwen3_next.py (1)

389-391: Potential runtime error when no CUDA device is available.

The check torch.cuda.get_device_properties(0).major == 10 will raise a RuntimeError if no CUDA device is present. While this is a training config where CUDA is expected, consider adding a guard:

     if disable_jit_fuser is None:
-        disable_jit_fuser = torch.cuda.get_device_properties(0).major == 10
+        disable_jit_fuser = torch.cuda.is_available() and torch.cuda.get_device_properties(0).major == 10

This would make the config generation more robust for dry-run or validation scenarios.

src/megatron/bridge/recipes/gemma/gemma2.py (2)

17-17: Consider using modern type annotation syntax.

Per coding guidelines, prefer:

  • list[str] instead of List[str]
  • str | None instead of Optional[str]
  • MixedPrecisionConfig | str instead of Union[MixedPrecisionConfig, str]

The imports on line 17 use the older typing module equivalents.

Proposed fix
-from typing import List, Optional, Union

Then update type hints throughout the file (e.g., in Gemma2CommonKwargs and Gemma2FinetuneKwargs classes) to use modern syntax like list[str] | None.


117-362: Significant code duplication across Gemma2 pretrain configs.

The three pretrain functions (gemma2_2b_pretrain_config, gemma2_9b_pretrain_config, gemma2_27b_pretrain_config) share ~95% identical code, differing only in:

  • HuggingFace model path
  • tensor_model_parallel_size (2, 8, 8)
  • pipeline_model_parallel_size (1, 1, 2)
  • pipeline_dtype (None, bfloat16, bfloat16)

Consider extracting a shared helper to reduce maintenance burden:

def _gemma2_pretrain_base(
    hf_path: str,
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    pipeline_dtype: torch.dtype | None,
) -> ConfigContainer:
    cfg = _pretrain_common()
    cfg.model = AutoBridge.from_hf_pretrained(hf_path).to_megatron_provider(load_weights=False)
    cfg.tokenizer.tokenizer_model = hf_path
    # ... common settings ...
    cfg.model.tensor_model_parallel_size = tensor_model_parallel_size
    cfg.model.pipeline_model_parallel_size = pipeline_model_parallel_size
    cfg.model.pipeline_dtype = pipeline_dtype
    return cfg

This would make each variant a simple 5-line wrapper. However, the current explicit approach does have the benefit of making each recipe self-contained and easy to customize independently.

src/megatron/bridge/recipes/qwen/qwen3_moe.py (1)

16-16: Same typing style suggestion as gemma2.py.

Consider updating to modern type annotation syntax (list[str] | None instead of Optional[List[str]]).

src/megatron/bridge/recipes/moonlight/moonlight_16b.py (1)

130-148: Simplify the list copy expression.

The expression list([list(x) for x in layout]) is redundant—the outer list() call is unnecessary since the list comprehension already produces a list.

Suggested fix
-    if layout is not None:
-        layout = list([list(x) for x in layout])
+    if layout is not None:
+        layout = [list(x) for x in layout]
src/megatron/bridge/recipes/glm/glm45.py (1)

267-382: Significant duplication with glm45_355b_pretrain_config.

The glm45_air_106b_pretrain_config shares ~90% of its code with glm45_355b_pretrain_config. Consider extracting the common configuration logic into a private helper to reduce maintenance burden. The key differences are:

  • HF path: "zai-org/GLM-4.5-Air" vs "zai-org/GLM-4.5"
  • Parallelism: TP=1, PP=4, EP=8 vs TP=2, PP=8, EP=16
src/megatron/bridge/recipes/kimi/kimi_k2.py (1)

43-44: Simplify the list copy expression.

Same redundant pattern as noted in moonlight_16b.py.

Suggested fix
     if layout is not None:
-        layout = list([list(x) for x in layout])
+        layout = [list(x) for x in layout]
src/megatron/bridge/recipes/olmoe/olmoe_7b.py (1)

128-145: Pipeline layout helper is well-documented.

The comment noting "OLMoE has 16 layers" helps understand the layout mappings. Consider simplifying the list copy:

Suggested fix
     if layout is not None:
-        layout = list([list(x) for x in layout])
+        layout = [list(x) for x in layout]
src/megatron/bridge/recipes/qwen/qwen2.py (1)

516-927: Remaining Qwen2.5 pretrain configs follow consistent pattern.

The configs for 1.5B, 7B, 14B, 32B, and 72B all follow the same structure with appropriate parallelism scaling. The 32B and 72B configs correctly set pipeline_dtype=torch.bfloat16 for PP > 1.

Consider extracting common configuration logic.

There's significant duplication across all 11 pretrain configs. A private helper like _qwen2_pretrain_common(hf_path, tp, pp, ...) could reduce ~800 lines to ~200 while maintaining explicit per-variant entry points.

src/megatron/bridge/recipes/llama/llama3.py (4)

43-88: Llama3CommonKwargs appears to be unused after the refactor.

The pretrain config functions are now parameterless and no longer accept **kwargs. This TypedDict is dead code that should be removed to avoid confusion.

♻️ Suggested fix
-class Llama3CommonKwargs(TypedDict, total=False):
-    """Typed options accepted by Llama3 family recipe helpers."""
-
-    # Core identifiers
-    hf_path: str
-    dir: str | None
-    name: str
-    # Dataset configuration
-    data_paths: list[str] | None
-    data_args_path: str | None
-    train_data_path: list[str] | None
-    valid_data_path: list[str] | None
-    test_data_path: list[str] | None
-    per_split_data_args_path: str | None
-    mock: bool
-    # Model configuration
-    tensor_model_parallel_size: int
-    pipeline_model_parallel_size: int
-    pipeline_dtype: torch.dtype | None
-    virtual_pipeline_model_parallel_size: int | None
-    context_parallel_size: int
-    sequence_parallel: bool
-    use_megatron_fsdp: bool
-    account_for_embedding_in_pipeline_split: bool
-    account_for_loss_in_pipeline_split: bool
-    # Training hyperparameters
-    train_iters: int
-    global_batch_size: int
-    micro_batch_size: int
-    seq_length: int
-    lr: float
-    min_lr: float
-    adam_eps: float
-    lr_warmup_iters: int
-    lr_decay_iters: int | None
-    eval_interval: int
-    save_interval: int
-    use_null_tokenizer: bool
-    # W&B logging
-    wandb_project: str | None
-    wandb_entity: str | None
-    wandb_exp_name: str | None
-    # Precision / overlap configs
-    precision_config: MixedPrecisionConfig | str | None
-    comm_overlap_config: CommOverlapConfig | None
-
-

140-234: Significant code duplication across all pretrain config functions.

All 14 pretrain config functions share ~50 identical lines (transformer_impl, cuda_graph settings, kernel selections, memory saving, optimizer precision, DDP config). Only the HF path, parallelism settings, and seq_length vary. Consider extracting common post-setup into a helper or using a builder pattern.

♻️ Example approach - extract common settings to a helper
def _apply_common_pretrain_settings(cfg: ConfigContainer, seq_length: int = 8192) -> None:
    """Apply common settings shared by all Llama pretrain configs."""
    # Tokenizer - NullTokenizer by default
    cfg.tokenizer.tokenizer_type = "NullTokenizer"
    cfg.tokenizer.tokenizer_model = None
    cfg.tokenizer.vocab_size = DEFAULT_NULL_TOKENIZER_VOCAB_SIZE

    # Dataset
    cfg.dataset.blend = None
    cfg.dataset.num_workers = 8
    cfg.dataset.seq_length = seq_length

    # Training
    cfg.train.train_iters = 1168251
    cfg.train.global_batch_size = 512
    cfg.train.micro_batch_size = 1
    cfg.train.eval_interval = 2000
    cfg.train.manual_gc = True
    cfg.train.manual_gc_interval = 100

    cfg.scheduler.lr_warmup_iters = 2000
    cfg.logger.log_timers_to_tensorboard = True

    # TE & CUDA Graph
    cfg.model.transformer_impl = "transformer_engine"
    cfg.model.cuda_graph_impl = "none"
    cfg.model.cuda_graph_scope = "full"
    cfg.model.cuda_graph_warmup_steps = 3

    # Kernel selections
    cfg.model.attention_backend = None
    cfg.model.cross_entropy_loss_fusion = True
    cfg.model.cross_entropy_fusion_impl = "te"

    # Memory saving
    cfg.model.recompute_granularity = None
    cfg.model.recompute_modules = None
    cfg.model.fine_grained_activation_offloading = False
    cfg.model.offload_modules = None

    # Optimizer precision
    cfg.optimizer.use_precision_aware_optimizer = False
    cfg.optimizer.main_grads_dtype = torch.float32
    cfg.optimizer.main_params_dtype = torch.float32
    cfg.optimizer.exp_avg_dtype = torch.float32
    cfg.optimizer.exp_avg_sq_dtype = torch.float32

    cfg.checkpoint.save_interval = 500

    # DDP
    cfg.ddp.overlap_grad_reduce = True
    cfg.ddp.overlap_param_gather = True
    cfg.ddp.check_for_nan_in_grad = True
    cfg.ddp.use_distributed_optimizer = True
    cfg.ddp.use_megatron_fsdp = False
    cfg.ddp.grad_reduce_in_fp32 = True
    cfg.ddp.average_in_collective = True
    cfg.ddp.data_parallel_sharding_strategy = "no_shard"


def llama32_1b_pretrain_config() -> ConfigContainer:
    """Return a pre-training config for Llama 3.2 1B."""
    cfg = _pretrain_common()
    cfg.model = AutoBridge.from_hf_pretrained("meta-llama/Llama-3.2-1B").to_megatron_provider(load_weights=False)
    
    _apply_common_pretrain_settings(cfg, seq_length=8192)
    
    # Model-specific parallelism
    cfg.model.tensor_model_parallel_size = 1
    cfg.model.pipeline_model_parallel_size = 1
    cfg.model.context_parallel_size = 1
    cfg.model.sequence_parallel = False
    cfg.model.seq_length = 8192
    
    return cfg

158-168: seq_length is set in two places - consider using a single source of truth.

Both cfg.dataset.seq_length and cfg.model.seq_length must be kept in sync. Consider defining a local constant to avoid divergence:

+    seq_length = 8192
+
     # Dataset config - mock data by default
     cfg.dataset.blend = None
     cfg.dataset.num_workers = 8
-    cfg.dataset.seq_length = 8192
+    cfg.dataset.seq_length = seq_length
     ...
-    cfg.model.seq_length = 8192
+    cfg.model.seq_length = seq_length

224-233: Several DDP settings redundantly override _pretrain_common() defaults.

Settings like overlap_grad_reduce, overlap_param_gather, check_for_nan_in_grad, grad_reduce_in_fp32, and average_in_collective are already set to the same values in _pretrain_common(). Only data_parallel_sharding_strategy and use_megatron_fsdp are actual overrides. Consider keeping only the intentional overrides to improve readability.

♻️ Keep only intentional overrides
     # DDP config
-    cfg.ddp.overlap_grad_reduce = True
-    cfg.ddp.overlap_param_gather = True
-    cfg.ddp.check_for_nan_in_grad = True
-    cfg.ddp.use_distributed_optimizer = True
+    # Override sharding strategy from _pretrain_common default
     cfg.ddp.use_megatron_fsdp = False
-    cfg.ddp.grad_reduce_in_fp32 = True
-    cfg.ddp.average_in_collective = True
     cfg.ddp.data_parallel_sharding_strategy = "no_shard"

Comment on lines +152 to +157
# Parallelism settings
cfg.model.tensor_model_parallel_size = 2
cfg.model.pipeline_model_parallel_size = 4
cfg.model.pipeline_model_parallel_layout = None
cfg.model.pipeline_dtype = None
cfg.model.virtual_pipeline_model_parallel_size = None
Contributor

@coderabbitai coderabbitai bot Jan 28, 2026


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's verify the file exists and check the exact lines mentioned
head -n 280 src/megatron/bridge/recipes/gpt_oss/gpt_oss.py | tail -n 130 | cat -n

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 5970


🏁 Script executed:

# Check for all pipeline_dtype mentions in recipes and training code
echo "=== pipeline_dtype in recipes ==="
rg -n "pipeline_dtype" src/megatron/bridge/recipes -g '*.py' -A 2

echo -e "\n=== pipeline_dtype in training ==="
rg -n "pipeline_dtype" src/megatron/bridge/training -g '*.py' -A 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 42328


🏁 Script executed:

# Find all PP>1 configs to compare pipeline_dtype settings
echo "=== PP>1 configs ==="
rg -n "pipeline_model_parallel_size\s*=\s*[2-9]" src/megatron/bridge/recipes -g '*.py' -B 5 -A 5

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 29932


Set pipeline_dtype to torch.bfloat16 for PP>1 in GPT-OSS configs.

Lines 156 and 271 set cfg.model.pipeline_dtype = None while pipeline_model_parallel_size = 4. Across the codebase, all other PP>1 configs (Qwen, Deepseek, Llama, Gemma) explicitly set pipeline_dtype to torch.bfloat16, with inline comments indicating it is required for PP > 1. This inconsistency can cause validation failures or unintended numerical changes.

Suggested fix
-    cfg.model.pipeline_dtype = None
+    cfg.model.pipeline_dtype = torch.bfloat16

Apply to both lines 156 and 271.


Contributor Author


GPT-OSS does have cfg.model.pipeline_dtype set to None explicitly in the current config, hence I left it that way to match the old behavior.

Contributor


🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the current main branch to see what the old config had at line 178
git fetch origin main 2>/dev/null || true
git show origin/main:src/megatron/bridge/recipes/gpt_oss/gpt_oss.py | head -n 185 | tail -n 15

echo -e "\n=== Checking for pipeline_dtype usage in MoE models (PP>1) ==="
# Check other MoE models (OLMoE, DeepSeek, GLM, etc.) to see their pipeline_dtype patterns
rg -n "expert_model_parallel_size.*=" src/megatron/bridge/recipes -g '*.py' -A 10 -B 5 | rg -A 10 -B 5 "pipeline_dtype"

Length of output: 50024


@athitten, thank you for the clarification! You're right that the old _gpt_oss_common function had pipeline_dtype default to None, and preserving that behavior makes sense for backward compatibility.

I did notice that the codebase has mixed patterns for MoE models with PP>1:

  • Models using torch.bfloat16: DeepSeek v2/v3, Qwen3_next, Qwen3_moe, Kimi K2, with explicit comments like "Required for PP > 1"
  • Models using None: GLM45 (both 355B with PP=8 and Air 106B with PP=4) also use None with PP>1

Since GLM45 follows the same pattern as GPT-OSS, your approach is consistent with at least some of the MoE models in the codebase. If the old behavior worked correctly with None, then maintaining it is the right call for this refactor.

Just wanted to flag the inconsistency in case it's worth verifying whether None vs torch.bfloat16 has any performance or correctness implications for PP>1, but I understand you're prioritizing backward compatibility here. 👍


Contributor


@athitten: can you help update the pipeline dtype to bf16? I think the old value is not good.

Contributor


Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Comment on lines +1013 to +1017
# Communication overlap for 70B
cfg.comm_overlap = CommOverlapConfig(
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
)
Contributor

@coderabbitai coderabbitai bot Jan 28, 2026


⚠️ Potential issue | 🟠 Major

Same communication overlap mismatch as 70B 16K config.

This config also uses TP=8 (line 952) with the TP=4 comm overlap config.


Contributor Author

@athitten athitten Jan 28, 2026


All of them (70B, 70B 16k, 70B 64k) use the same userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 comm overlap config. This matches the existing behavior.

Contributor


🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for all available userbuffers comm overlap configurations in the codebase
rg -n "^def userbuffers_" --type py -A 1

# Also check what's exported from comm_overlap module
rg -n "userbuffers_" src/megatron/bridge/training/comm_overlap.py -C 2

Length of output: 4930


@athitten, thank you for clarifying! I can see from the available comm overlap configurations that there is indeed no TP=8 variant for h8192 (70B hidden size) models on H100 - only the TP=4 variant userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 exists for BF16 on H100.

The available configurations show:

  • h8192 (70B): TP=4 and TP=2 (LoRA) variants only
  • h16384 (405B): TP=8 variant available

So it makes sense that all 70B configs (including the TP=8 variants for 16K and 64K) use the same TP=4 comm overlap config as the closest match. I'll note this as intentional behavior.


🧬 Code Graph Analysis Results

<src/megatron/bridge/recipes/llama/llama3.py> // Summary description of relevant snippet

  • Summary:

    • In the pretraining configurations for Llama 3 70B variants (e.g., llama3_70b_pretrain_config and llama3_70b_16k_pretrain_config), the code enables communication overlap by constructing a CommOverlapConfig with tensor-parallel overlap enabled and using a pre-defined TP overlap configuration.
    • The specific overlap configuration used is the pre-defined function/variable named userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192, assigned to the tp_comm_overlap_cfg field of CommOverlapConfig.
    • This pattern is referenced by the user comment (70B, 70B 16k, 70B 64k all using the same userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192 config) and is visible in the llama3.py pretrain config sections for 70B with CommOverlap enabled.

    Details to look for (conceptually):

    • Definition or import of CommOverlapConfig
    • Import/definition of userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
    • Instances of:
      • llama3_70b_pretrain_config() -> CommOverlapConfig(tp_comm_overlap=True, tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192)
      • llama3_70b_16k_pretrain_config() -> similarly sets CommOverlapConfig with the same tp_comm_overlap_cfg

    Parameters and behavior (from context in file and related module):

    • tp_comm_overlap: boolean to enable tensor-parallel overlap
    • tp_comm_overlap_cfg: a specific configuration object/function (here: userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192) describing overlap specifics (TP, MBS, sequence length)
    • This setup is intended to optimize comms for 70B variants with SP enabled and specific parallelism patterns.

    File/Location:

    • Source file: src/megatron/bridge/recipes/llama/llama3.py
    • Context: within pretrain config factory functions for 70B variants (e.g., llama3_70b_pretrain_config and llama3_70b_16k_pretrain_config)
    • Note: The exact line numbers are not provided here, but this description corresponds to the “Communication overlap for 70B” sections where CommOverlapConfig is assigned with tp_comm_overlap and tp_comm_overlap_cfg set to userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192.

✏️ Learnings added
Learnt from: athitten
Repo: NVIDIA-NeMo/Megatron-Bridge PR: 2067
File: src/megatron/bridge/recipes/llama/llama3.py:1013-1017
Timestamp: 2026-01-28T18:47:54.691Z
Learning: In src/megatron/bridge/recipes/llama/llama3.py, all Llama 3 70B pretraining configs (llama3_70b_pretrain_config, llama3_70b_16k_pretrain_config, llama3_70b_64k_pretrain_config) intentionally use the same comm overlap config `userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192`, regardless of their TP setting (TP=4 or TP=8). This is because there is no TP=8 comm overlap variant available for h8192 models on H100 - the TP=4 config is the closest match.


@athitten
Contributor Author

/ok to test 15677f0

@athitten
Contributor Author

/ok to test e47a886

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@athitten
Contributor Author

/ok to test 7c9f47d

Comment on lines +159 to +165
# If --hf-path is provided, override the model's HuggingFace path
if args.hf_path:
    logger.info(f"Using custom HuggingFace path: {args.hf_path}")
    recipe_kwargs["hf_path"] = args.hf_path
    # Import AutoBridge to create a new model provider with the custom HF path
    from megatron.bridge.models import AutoBridge

cfg: ConfigContainer = pretrain_config(**recipe_kwargs)
logger.info("Loaded base configuration")
cfg.model = AutoBridge.from_hf_pretrained(args.hf_path).to_megatron_provider(load_weights=False)
Contributor


@yaoyu-33 won't this be an issue, since users now have to re-apply the other model configs that the recipe sets as defaults?
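
To illustrate the concern, here is a toy sketch using stand-in dataclasses (not the real `ConfigContainer` or model provider API, which have many more fields): assigning a freshly built provider to `cfg.model`, as the `--hf-path` branch does via `AutoBridge`, discards whatever model-level defaults the recipe had already applied.

```python
from dataclasses import dataclass, field

# Stand-in types for illustration only; field names are hypothetical.
@dataclass
class ModelProvider:
    hf_path: str = "some-org/base-model"
    tensor_model_parallel_size: int = 1

@dataclass
class ConfigContainer:
    model: ModelProvider = field(default_factory=ModelProvider)

# The recipe returns a config with tuned model defaults (e.g. TP=4):
cfg = ConfigContainer(model=ModelProvider(tensor_model_parallel_size=4))

# Overriding cfg.model with a freshly built provider resets those
# recipe-tuned fields back to their class defaults:
cfg.model = ModelProvider(hf_path="my-org/custom-model")
assert cfg.model.tensor_model_parallel_size == 1  # TP=4 from the recipe is lost
```

Under this sketch, a user who passes `--hf-path` would have to re-set every model field the recipe had customized.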

Signed-off-by: Abhishree <abhishreetm@gmail.com>
@athitten
Contributor Author

/ok to test ad67f81

@yaoyu-33
Contributor

Wait to merge after code freeze, release branch cut
