
[Feature] add gptoss continue train bf16-fp8 (sft) example [part1 - mcore] #2383

Closed
yiakwy-xpu-ml-framework-team wants to merge 4 commits into NVIDIA:main from yiakwy-xpu-ml-framework-team:add_gptoss_example

Conversation


@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team commented Nov 24, 2025

What does this PR do ?

⚠️ For major changes (either in lines of code or in impact), please make sure to first share and discuss a design doc with the team.

Add a GPT-OSS 20B training example on the Hopper platform:

Env :

  • CUDA 12.8 + torch 2.9.1
  • FlashAttention 2.8.3, FlashAttention 3 (latest)
  • TransformerEngine (PyTorch, latest: 2.9.0)
  • Triton 3.4 (to work with the Triton sorting kernel)
  • Base image: SGLang 0.5.0rc2

Supported parallel schemes:

  • PP=1, EP=8
  • PP=2, EP=4
  • TP=8, ETP=1, EP=8, PP=1

ETP must be 1, since ETP cannot be > 1 when add_bias is True (required by GPT-OSS).
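The constraint above can be sketched as a simple validation check (hypothetical function and argument names; Megatron-Core's actual assertion may differ):

```python
# Hypothetical sketch of the constraint stated above (illustrative names,
# not the actual Megatron-Core assertion): expert tensor parallelism (ETP)
# must stay at 1 when the expert linear layers add a bias, as GPT-OSS requires,
# because a bias cannot be partitioned across ETP ranks here.
def validate_moe_parallelism(expert_tensor_parallel_size: int,
                             add_bias_linear: bool) -> None:
    if add_bias_linear and expert_tensor_parallel_size > 1:
        raise ValueError(
            "ETP must be 1 when add_bias is True (required by GPT-OSS)"
        )

validate_moe_parallelism(1, True)  # OK: GPT-OSS with ETP=1
```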

Snapshot

Full GPT-OSS (24 layers) with TP8-EP8:

[Screenshots: training-log captures, 2025-11-27]

Steps to reproduce

  • Generate distributed checkpoint:

    torchrun $DISTRIBUTED_ARGS convert_mcore_bf16_checkpoint_from_hf.py 2>&1 | tee megatron_fwd.log
    
  • Start the training jobs:

    # slurm args
    bash training_gptoss_20b_120b_h100_bf16_fp8.sh   
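For reference, a minimal single-node DISTRIBUTED_ARGS for the torchrun command above might look like this (illustrative values; adjust nnodes, nproc_per_node, and master_addr for your cluster):

```shell
# Illustrative single-node DISTRIBUTED_ARGS (hypothetical values; the PR's
# actual launch configuration may differ):
DISTRIBUTED_ARGS="--nnodes=1 --nproc_per_node=8 --node_rank=0 --master_addr=localhost --master_port=6000"
echo "$DISTRIBUTED_ARGS"
```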
    

Changes

  • megatron core:

    • experts: the bias should be applied to the output using the unpadded tokens_per_expert
  • GPT-OSS YaRN config:

    • add the GPT-OSS YaRN config before model construction
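The experts change can be illustrated with a small NumPy sketch (illustrative only, not the actual Megatron-Core code): the per-expert bias must be expanded using the unpadded tokens_per_expert counts, so each output row receives the bias of the expert that actually produced it rather than a bias aligned to padded slots.

```python
import numpy as np

# Illustrative sketch of the fix described above (not Megatron-Core code):
# expand the per-expert bias with the *unpadded* tokens_per_expert so each
# output row gets the bias of the expert that produced it.
def add_expert_bias(output: np.ndarray, tokens_per_expert: np.ndarray,
                    bias: np.ndarray) -> np.ndarray:
    # output: [sum(tokens_per_expert), hidden]; bias: [num_experts, hidden]
    expert_ids = np.repeat(np.arange(len(tokens_per_expert)), tokens_per_expert)
    return output + bias[expert_ids]

# Two experts with 2 and 3 (unpadded) tokens respectively
out = np.zeros((5, 4))
bias = np.arange(8, dtype=float).reshape(2, 4)
res = add_expert_bias(out, np.array([2, 3]), bias)
```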

Other Related Issues:

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.


copy-pr-bot bot commented Nov 24, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Contributor

sbhavani commented Dec 1, 2025

@yiakwy-xpu-ml-framework-team thanks for creating this example!

We are currently refactoring config management to improve validation and reduce the number of default args just like in Megatron Bridge. I'd recommend we wait until the refactor is done this month and then merge the example.

I also noticed you require other changes outside of examples/ and arguments.py. I think it'd be better to split those changes into separate PRs (experts.py bug fix, doc fixes, etc.) from the GPT-OSS example.


from megatron.bridge import AutoBridge
from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0
from megatron.bridge.training.model_load_save import load_megatron_model, save_megatron_model, load_tokenizer

Hi, it's not ideal to import bridge from the megatron-lm side; it will cause a cyclic dependency. Also, the megatron-lm env doesn't require users to install bridge.

do you think it's okay to put the example in bridge?


@yaoyu-33 I see, this is a tool in examples independent of Megatron-Core. I didn't encounter the cyclic reference problem.

I use this script to generate megatron distributed checkpoint and verify it.


it's more like we are not asking megatron-lm users to install bridge, so it's a bit confusing to put things here. I think it needs other people's opinions to see what's the best way.

@@ -0,0 +1,43 @@
#/usr/bin/bash

The llama changes should be in another PR, if this PR is mostly GPT-OSS related.


I noticed we only added verification for Llama 3. I just updated it from my local repo.

The main focus of this script is GPT-OSS. Yes, it would be better in a separate PR, but if it is not too inconvenient for you, we can add it to the repo in this PR :D

get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
create_gloo_process_groups: bool = True,
create_gloo_process_groups: bool = False,
This is a major change that affects all training, and it will increase the memory footprint. What's the reason?


@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team Dec 3, 2025


OK, I will disable it if the memory footprint is increased. I didn't observe side effects; I just found that by default many Gloo connections were created on the CPU side.

@fengxy-03

@yiakwy-xpu-ml-framework-team @yaoyu-33 Hi, I've implemented a GPT-OSS model (0.11B) based on your guidelines using Megatron-LM 0.16.0rc0. During training, the throughput per GPU is only around 1.0 TFLOP/s, which seems abnormal for this configuration.

Could you please take a look at my script and logs to see if there are any obvious misconfigurations?

Training Script

SEQ_LENGTH=8192
MAX_LENGTH=8192
TRAIN_SAMPLES=1518124
LAST_TRAIN_SAMPLES=0
LR_DECAY_SAMPLES=$(((TRAIN_SAMPLES - LAST_TRAIN_SAMPLES) * 80 / 100))
CHECKPOINT_PATH=
# ================end================
TOKENIZER_TYPE=SentencePieceTokenizer
TOKENIZER_MODEL=
MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=128

DISTRIBUTED_ARGS=" \
    --nnodes=1 \
    --nproc_per_node=8 \
    --node_rank=0 \
    --master_addr=localhost \
    --master_port=6000"

MODEL_ARGS=" \
    --no-masked-softmax-fusion \
    --transformer-impl transformer_engine \
    --disable-bias-linear \
    --untie-embeddings-and-output-weights \
    --no-rope-fusion \
    --normalization RMSNorm \
    --num-layers 12 \
    --hidden-size 512 \
    --ffn-hidden-size 2048 \
    --num-attention-heads 64 \
    --group-query-attention \
    --num-query-groups 8 \
    --seq-length 8192 \
    --max-position-embeddings 8192 \
    --use-mcore-models \
    --rotary-percent 1.0 \
    --rope-type yarn \
    --position-embedding-type yarn \
    --rotary-base 10000 \
    --no-bias-gelu-fusion \
    --export-force-local-attention \
    --no-bias-dropout-fusion \
    --quick-geglu \
    --glu-linear-offset 1.0 \
    --softmax-type learnable \
    --window-attn-skip-freq 2 \
    --activation-func-clamp-value 7.0 \
    --window-size 128,0 \
    --enable-gpt-oss"

MOE_ARGS=" \
    --num-experts 4 \
    --moe-router-topk 2 \
    --moe-router-load-balancing-type aux_loss \
    --moe-aux-loss-coeff 1e-3 \
    --moe-grouped-gemm \
    --moe-token-dispatcher-type alltoall \
    --overlap-param-gather \
    --overlap-grad-reduce \
    --moe-ffn-hidden-size 2048 \
    --moe-router-dtype fp32 \
    --moe-z-loss-coeff 1e-3 \
    --moe-permute-fusion"

DATA_ARGS=" \
    --num-workers 8 \
    --dataloader-type cyclic \
    --tokenizer-type ${TOKENIZER_TYPE} \
    --tokenizer-model ${TOKENIZER_MODEL} \
    --data-path \
    --split 1000,0,0 \
    --no-create-attention-mask-in-dataloader"

TRAINING_ARGS=" \
    --micro-batch-size ${MICRO_BATCH_SIZE} \
    --global-batch-size ${GLOBAL_BATCH_SIZE} \
    --lr 1.0e-5 \
    --train-samples ${TRAIN_SAMPLES} \
    --lr-decay-samples ${LR_DECAY_SAMPLES} \
    --lr-decay-style cosine \
    --min-lr 1.0e-6 \
    --weight-decay 0.1 \
    --lr-warmup-fraction 0.05 \
    --clip-grad 1.0 \
    --bf16 \
    --use-flash-attn \
    --attention-softmax-in-fp32 \
    --accumulate-allreduce-grads-in-fp32 \
    --disable-bf16-reduced-precision-matmul \
    --recompute-activations"

MODEL_PARALLEL_ARGS=" \
    --tensor-model-parallel-size 4 \
    --pipeline-model-parallel-size 1 \
    --expert-model-parallel-size 2 \
    --sequence-parallel \
    --context-parallel-size 1 \
    --use-distributed-optimizer \
    --fp8-format hybrid \
    --fp8-param-gather \
    --fp8-amax-compute-algo max \
    --fp8-amax-history-len 1024"
    
LOGGING_ARGS=" \
    --log-interval 1 \
    --save-interval 10000 \
    --eval-interval 50000000 \
    --eval-iters 0 \
    --save $CHECKPOINT_PATH \
    --tensorboard-dir ${CHECKPOINT_PATH}/tensorboard \
    --wandb-project ${WANDB_PROJECT:-"gpt-oss"} \
    --wandb-exp-name ${WANDB_NAME:-"gpt-oss-test"} \
    --moe-per-layer-logging \
    --no-load-optim \
    --no-load-rng \
    --log-throughput"


python -m torch.distributed.run ${DISTRIBUTED_ARGS} pretrain_gpt.py \
    ${MODEL_ARGS} \
    ${MOE_ARGS} \
    ${DATA_ARGS} \
    ${TRAINING_ARGS} \
    ${MODEL_PARALLEL_ARGS} \
    ${LOGGING_ARGS}

Logs

 [2026-01-16 15:01:05] iteration        3/   11860 | consumed samples:          384 | elapsed time per iteration (ms): 75210.0 | throughput per GPU (TFLOP/s/GPU): 1.1 | learning rate: 6.323595E-08 | global batch size:   128 | lm loss: 6.102232E+00 | z_loss: 2.145113E+00 | load_balancing_loss: 1.155455E+00 | loss scale: 1.0 | grad norm: 75.384 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-01-16 15:02:20] iteration        4/   11860 | consumed samples:          512 | elapsed time per iteration (ms): 75126.0 | throughput per GPU (TFLOP/s/GPU): 1.1 | learning rate: 8.431460E-08 | global batch size:   128 | lm loss: 6.109521E+00 | z_loss: 2.139967E+00 | load_balancing_loss: 1.154540E+00 | loss scale: 1.0 | grad norm: 74.716 | number of skipped iterations:   0 | number of nan iterations:   0 |
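A rough sanity check of the reported throughput, assuming the standard ~6·N FLOPs-per-token approximation for forward+backward and that the "0.11B" figure is the active parameter count per token (both assumptions, not confirmed by the thread):

```python
# Back-of-envelope estimate of achieved TFLOP/s per GPU from the log numbers
# above. Assumes ~6 FLOPs per active parameter per token (fwd + bwd) and that
# 0.11e9 is the active parameter count -- both are assumptions.
def est_tflops_per_gpu(active_params: float, global_batch: int,
                       seq_len: int, iter_ms: float, num_gpus: int) -> float:
    tokens_per_s = global_batch * seq_len / (iter_ms / 1000.0)
    total_flops_per_s = tokens_per_s * 6 * active_params
    return total_flops_per_s / num_gpus / 1e12

# Numbers from the log: GBS=128, seq=8192, ~75210 ms/iter, 8 GPUs
print(round(est_tflops_per_gpu(0.11e9, 128, 8192, 75210.0, 8), 2))  # → 1.15
```

Under these assumptions the estimate roughly matches the logged 1.1 TFLOP/s/GPU, so the low FLOP rate may reflect the model's small size as much as any misconfiguration; the 75 s iteration time itself may be the more useful number to investigate.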


Phlip79 commented Mar 4, 2026

Please reference Megatron-Bridge for how to use bf16 and fp8 for GPT-OSS.

@Phlip79 Phlip79 closed this Mar 4, 2026


7 participants