Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
generation_config = GenerationConfig.from_pretrained(MODEL_ID)
config = Qwen3MoeConfig(
vocab_size=len(tokenizer.vocab),
vocab_size=151936,
hidden_size=8,
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=32,
num_experts=4,
num_experts_per_tok=2,
max_position_embeddings=40960,
rope_theta=1000000.0,
norm_topk_prob=True,
bos_token_id=151643,
eos_token_id=151645,
# Forwarded via kwargs (not Qwen3MoeConfig fields, but PretrainedConfig accepts arbitrary kwargs):
head_dim=128,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that will be useful for MFU utils. We can land this and I'll rebase 👍🏼 @qgallouedec

max_window_layers=48,
)
model = Qwen3MoeForCausalLM(config).to(dtype=torch.bfloat16)
init_weights_tiny_model(model)
Expand Down
13 changes: 12 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def pytest_runtest_makereport(item, call):

MODEL_REVISIONS = {
    # Maps a model id to the revision the test suite should use for it.
    # Add `model_id: revision` entries here to test PRs; when the mapping is empty,
    # `apply_model_revisions` returns immediately.
    "trl-internal-testing/tiny-Qwen3MoeForCausalLM": "refs/pr/1",
}


Expand All @@ -63,7 +64,14 @@ def apply_model_revisions(monkeypatch):
if not MODEL_REVISIONS:
return

from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
)

def create_classmethod_wrapper(original_classmethod):
# Extract the underlying function from the classmethod
Expand All @@ -83,6 +91,9 @@ def wrapper(cls, pretrained_model_name_or_path, *args, **kwargs):

# Patch all transformers Auto* classes
for cls in [
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Expand Down
87 changes: 86 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
import torch.nn.functional as F
import transformers
from packaging.version import Version
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available

from trl import ModelConfig
from trl.trainer.utils import (
RepeatSampler,
_ChunkedLogProbFunction,
adjusted_mfu,
compute_flops_per_token,
compute_mfu,
entropy_from_logits,
flush_left,
generate_model_card,
Expand Down Expand Up @@ -1283,3 +1286,85 @@ def test_backward(self, model_id, temperature):
chunked_grad = model_chunked.lm_head.weight.grad.clone()

torch.testing.assert_close(chunked_grad, ref_grad, atol=5e-2, rtol=5e-2)


class TestComputeFlopsPerToken(TrlTestCase):
    """Checks `compute_flops_per_token` through differences that isolate a single term of the estimate."""

    DENSE_MODEL_ID = "trl-internal-testing/tiny-Qwen3ForCausalLM"
    MOE_MODEL_ID = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"

    def test_seq_scaling_linear(self):
        # Only the attention-score term depends on seq_len, and it does so linearly. Every
        # seq-len-independent term cancels when two estimates are subtracted, so doubling the
        # seq_len gap must double the difference: `F(32k) - F(16k) == 2 * (F(16k) - F(8k))`.
        config = AutoConfig.from_pretrained(self.DENSE_MODEL_ID)
        flops_8k, flops_16k, flops_32k = (
            compute_flops_per_token(config, seq_len) for seq_len in (8192, 16384, 32768)
        )
        assert flops_32k - flops_16k == 2 * (flops_16k - flops_8k)

    def test_tied_vs_untied_lm_head(self):
        # An untied lm_head contributes `2 * V * h` forward FLOPs, tripled for fwd+bwd.
        config = AutoConfig.from_pretrained(self.DENSE_MODEL_ID)
        flops_by_tying = {}
        for tied in (True, False):
            config.tie_word_embeddings = tied
            flops_by_tying[tied] = compute_flops_per_token(config, 16384)
        lm_head_training_flops = 3 * 2 * config.vocab_size * config.hidden_size
        assert flops_by_tying[False] - flops_by_tying[True] == lm_head_training_flops

    def test_moe_active_vs_total_experts(self):
        # Going from 1 to 2 active experts (`num_experts_per_tok`) must change the estimate by
        # exactly one routed expert per MoE layer: 3 matmuls × 2 × h × moe_intermediate, tripled
        # for fwd+bwd. The total number of experts is left untouched, so the router term is the
        # same in both estimates and cancels in the difference.
        config = AutoConfig.from_pretrained(self.MOE_MODEL_ID)
        seq_len = 16384
        estimates = []
        for active_experts in (1, 2):
            config.num_experts_per_tok = active_experts
            estimates.append(compute_flops_per_token(config, seq_len))
        flops_one_active, flops_two_active = estimates
        num_moe_layers = len(
            [idx for idx in range(config.num_hidden_layers) if idx % config.decoder_sparse_step == 0]
        )
        flops_per_expert = 2 * 3 * config.hidden_size * config.moe_intermediate_size
        expected_delta = 3 * num_moe_layers * (2 - 1) * flops_per_expert
        assert flops_two_active - flops_one_active == expected_delta


class TestComputeMfu(TrlTestCase):
    """Checks `compute_mfu` against a throughput constructed to saturate the hardware."""

    def test_perfect_utilization(self):
        # A cluster that sustains exactly `peak * world_size / flops_per_token` tokens per second
        # is using all of its theoretical FLOPs, so the reported MFU must be 100%.
        peak_per_device = 989.5e12
        num_devices = 8
        flops_per_token = 100e9
        saturating_tps = peak_per_device * num_devices / flops_per_token
        mfu = compute_mfu(flops_per_token, saturating_tps, num_devices, peak_flops_per_device=peak_per_device)
        assert mfu == pytest.approx(100.0)


class TestAdjustedMfu(TrlTestCase):
    """Checks the causal-masking correction applied by `adjusted_mfu`."""

    MOE_MODEL_ID = "trl-internal-testing/tiny-Qwen3MoeForCausalLM"

    def test_consistent_with_formula(self):
        # Cross-check the two helpers: `adjusted_mfu(mfu, cfg, seq_len)` must equal
        # `mfu * (full - half_attn) / full`, where `full = compute_flops_per_token(cfg, seq_len)`
        # and `half_attn = L * 3 * 2 * n_heads * head_dim * seq_len`.
        config = AutoConfig.from_pretrained(self.MOE_MODEL_ID)
        seq_len = 16384
        total_flops = compute_flops_per_token(config, seq_len)
        score_width = config.num_attention_heads * config.head_dim
        masked_flops = config.num_hidden_layers * 3 * 2 * score_width * seq_len
        expected = 100.0 * (total_flops - masked_flops) / total_flops
        assert adjusted_mfu(100.0, config, seq_len) == pytest.approx(expected)

    def test_proportional_to_input(self):
        # The correction is a pure multiplicative factor on `mfu`, so doubling the input MFU
        # must double the output.
        config = AutoConfig.from_pretrained(self.MOE_MODEL_ID)
        from_half = adjusted_mfu(50.0, config, 16384)
        from_full = adjusted_mfu(100.0, config, 16384)
        assert from_full == pytest.approx(2 * from_half)

    def test_decreases_with_seq_len(self):
        # Longer sequences give attention a larger share of the total compute, so the causal
        # correction removes a larger fraction and the adjusted value strictly decreases.
        config = AutoConfig.from_pretrained(self.MOE_MODEL_ID)
        at_4k, at_16k, at_64k = (adjusted_mfu(100.0, config, seq_len) for seq_len in (4096, 16384, 65536))
        assert at_4k > at_16k > at_64k
114 changes: 114 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import transformers
from accelerate import PartialState, logging
from huggingface_hub import ModelCard, ModelCardData
from packaging.version import Version
from torch.utils.data import Sampler
from transformers import (
AutoConfig,
Expand Down Expand Up @@ -1319,3 +1320,116 @@ def _chunked_forward(
}

model.forward = types.MethodType(_chunked_forward, model)


def compute_flops_per_token(config: PretrainedConfig, seq_len: int) -> int:
"""
Estimate training FLOPs per token for a causal language model (forward + backward).

Supports dense and MoE architectures. Backward is assumed to cost 2× the forward pass, so total training FLOPs = 3
× forward FLOPs. The attention-score term uses the non-causal convention (every token attends to the full
`seq_len`, matching PaLM / Megatron / nanoGPT); pass the resulting MFU through [`adjusted_mfu`] for the Llama /
DeepSpeed Ulysses causal-corrected convention.

Args:
config ([`~transformers.PretrainedConfig`]):
The model configuration.
seq_len (`int`):
The sequence length used for training.

Returns:
`int`: Estimated training FLOPs per token.
"""
h = config.hidden_size
L = config.num_hidden_layers
V = config.vocab_size
n_heads = config.num_attention_heads
n_kv_heads = config.num_key_value_heads
head_dim = config.head_dim

# Attention: Q/K/V/O projections + attention score (Q·Kᵀ and attn·V).
qkv_flops = 2 * h * (n_heads * head_dim + 2 * n_kv_heads * head_dim)
o_proj_flops = 2 * n_heads * head_dim * h
attn_score_flops = 2 * 2 * n_heads * head_dim * seq_len
attn_flops = qkv_flops + o_proj_flops + attn_score_flops

# MoE dispatch: `num_experts_per_tok` is the canonical MoE marker — present on Mixtral,
# Qwen3-MoE, DeepSeek-V2, etc.; absent on dense configs.
num_experts_per_tok = getattr(config, "num_experts_per_tok", None)
if num_experts_per_tok is None:
mlp_flops = 2 * 3 * h * config.intermediate_size
total_layer_flops = L * (attn_flops + mlp_flops)
else:
# Routed experts (gate + up + down, 3 matmuls each) + router.
if Version(transformers.__version__) >= Version("5.1.0"):
num_experts = config.num_local_experts
else:
num_experts = config.num_experts
moe_mlp_flops = num_experts_per_tok * 2 * 3 * h * config.moe_intermediate_size
moe_mlp_flops += 2 * h * num_experts
dense_mlp_flops = 2 * 3 * h * config.intermediate_size # interspersed dense layers
sparse_step = config.decoder_sparse_step
total_layer_flops = sum(
attn_flops + (moe_mlp_flops if layer_idx % sparse_step == 0 else dense_mlp_flops) for layer_idx in range(L)
)

embed_flops = 2 * V * h
lm_head_flops = 0 if config.tie_word_embeddings else 2 * V * h

forward_flops = total_layer_flops + embed_flops + lm_head_flops
return 3 * forward_flops


def compute_mfu(
    flops_per_token: int,
    tokens_per_second: float,
    world_size: int,
    peak_flops_per_device: float = 989.5e12,
) -> float:
    """
    Return the Model FLOPs Utilization (MFU), expressed as a percentage.

    MFU is the ratio between the FLOPs per second the training run achieves (`flops_per_token * tokens_per_second`)
    and the FLOPs per second the hardware could theoretically deliver (`peak_flops_per_device * world_size`).

    `tokens_per_second` must already be corrected by the caller for any parallelism dimension that makes the trainer's
    token counter over-count (e.g. context parallelism, sequence parallelism, tensor parallelism — every rank in
    those dims sees the same input tokens).

    Args:
        flops_per_token (`int`):
            Training FLOPs per token (from [`compute_flops_per_token`]).
        tokens_per_second (`float`):
            Aggregate tokens per second across all devices, after any parallelism corrections.
        world_size (`int`):
            Number of devices (GPUs).
        peak_flops_per_device (`float`, *optional*, defaults to `989.5e12`):
            Theoretical peak FLOPs per device in bf16. Defaults to H100 SXM5.

    Returns:
        `float`: MFU as a percentage (0-100).
    """
    achieved_flops_per_second = flops_per_token * tokens_per_second
    cluster_peak_flops_per_second = peak_flops_per_device * world_size
    return 100 * achieved_flops_per_second / cluster_peak_flops_per_second


def adjusted_mfu(mfu: float, config: PretrainedConfig, seq_len: int) -> float:
    """
    Rescale an MFU obtained with [`compute_flops_per_token`] to the causal-attention convention.

    [`compute_flops_per_token`] counts attention-score FLOPs (`Q·Kᵀ` and `attn·V`) as if every token attended to the
    full `seq_len` (the PaLM / Megatron / nanoGPT convention). Under causal masking only half of those FLOPs are
    actually performed. This function removes that half from the per-token total and scales `mfu` by the resulting
    ratio, which makes the value comparable with reports that follow the Llama 2/3 / DeepSpeed Ulysses convention.

    Args:
        mfu (`float`):
            MFU as a percentage, computed via [`compute_mfu`] (i.e., using the non-causal [`compute_flops_per_token`]).
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`):
            The sequence length used for training.

    Returns:
        `float`: Causal-corrected MFU as a percentage.
    """
    total_flops = compute_flops_per_token(config, seq_len)
    num_layers = config.num_hidden_layers
    score_width = config.num_attention_heads * config.head_dim
    # FLOPs skipped by the causal mask: half of the attention-score work in every layer, ×3 for fwd+bwd.
    masked_flops = num_layers * 3 * 2 * score_width * seq_len
    causal_flops = total_flops - masked_flops
    return mfu * causal_flops / total_flops
Loading