Merged
40 commits
e4bad3b
block scaling
jiemingz Jun 24, 2025
d19565d
big fixes
jiemingz Jun 25, 2025
2daa0c6
fix pow2
jiemingz Jun 26, 2025
880fa65
use deepgemm
jiemingz Jun 27, 2025
9f179b9
pow2 x and w
jiemingz Jun 28, 2025
da5e3b9
fist last bf16
jiemingz Jun 28, 2025
b9798e9
fix rebase
jiemingz Jun 28, 2025
873f093
fix bf16 first last
jiemingz Jun 29, 2025
8607d17
fix again
jiemingz Jun 30, 2025
b537c59
remove most of the hardcoding
jiemingz Jul 1, 2025
e1839b8
refactor, use fp8 config
jiemingz Jul 1, 2025
27908e8
more refactoring
jiemingz Jul 1, 2025
494c0a4
add deps
jiemingz Jul 8, 2025
3167e85
addresss comments
jiemingz Jul 14, 2025
4127842
linting
jiemingz Jul 14, 2025
a5297e6
tests
jiemingz Jul 14, 2025
9523eb0
separate first last layers in bf16
jiemingz Jul 14, 2025
4a32ad2
lint
jiemingz Jul 14, 2025
ca502a2
ensure importance sampling on
jiemingz Jul 14, 2025
f8b2e99
add fp8 config
jiemingz Jul 21, 2025
b7d4a21
fix TP and async engine
jiemingz Jul 24, 2025
7633ee2
lint
jiemingz Jul 24, 2025
51c2cf0
Update grpo.py
jiemingz Jul 24, 2025
dbd3572
add doc, fix single gpu case
jiemingz Jul 29, 2025
10dac2d
fix async
jiemingz Jul 29, 2025
f441f59
fix rebase
jiemingz Aug 7, 2025
dd8d345
fix tests
jiemingz Aug 10, 2025
1bfa5e1
fix lint
jiemingz Aug 11, 2025
4d9ecbd
fix rebase
jiemingz Aug 11, 2025
4d19688
Lint/copyright
SahilJain314 Aug 12, 2025
36dae19
address comments
jiemingz Aug 12, 2025
9fc983a
Update docs/fp8.md
jiemingz Aug 12, 2025
192764f
fix sphinx
jiemingz Aug 14, 2025
e715b7a
fix noncolocate
jiemingz Aug 18, 2025
6c85b1e
skip fp8 tests on <h100
jiemingz Aug 20, 2025
2573ae5
uv lock
jiemingz Aug 20, 2025
b2d7e9a
add functional
jiemingz Aug 21, 2025
5cf2e68
Merge branch 'main' into jiemingz/fp8_block
jiemingz Aug 21, 2025
33988a8
Update grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.yaml
jiemingz Aug 21, 2025
4d30861
add missed cfgs
jiemingz Aug 21, 2025
Binary file added docs/assets/fp8_curves.png
39 changes: 39 additions & 0 deletions docs/fp8.md
@@ -0,0 +1,39 @@
# FP8 for NeMo-RL

This module provides a suite of tools to enable FP8 quantization for large language models. It is still under development. Currently, we support FP8 generation using DeepSeek-style FP8 (sub-channel scaling).

NeMo-RL monkey-patches several vLLM functions to enable FP8 generation for reinforcement learning. When initialized, the `init_fp8` function patches the following key vLLM components (a minimal sketch of the patching pattern follows the list):
1. **`RayDistributedExecutor`**: For multi-GPU inference, the executor is patched to ensure that every worker process applies the same FP8 patches before model initialization.
2. **Quantization Utilities**: Functions within `vllm.model_executor.layers.quantization` are replaced with versions that support power-of-2 scaling and other custom features.
3. **Weight Loading**: A custom `load_weights` function handles the on-the-fly quantization of model weights from a higher-precision format to FP8 with the correct scaling factors.
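
The snippet below is a minimal, illustrative sketch of the monkey-patching pattern used here. The helper names (`patch_attr`, `make_fp8_weight_loader`) are hypothetical stand-ins for illustration only, not the actual NeMo-RL or vLLM API; the real patch targets are applied by `init_fp8`.

```python
# Illustrative sketch only: `patch_attr` and `make_fp8_weight_loader` are
# hypothetical names, not the actual NeMo-RL implementation or vLLM API.
import types
from typing import Any, Callable


def patch_attr(target: Any, name: str, replacement: Callable[..., Any]) -> None:
    """Replace `target.<name>` with `replacement`, keeping the original reachable."""
    original = getattr(target, name)
    setattr(target, f"_original_{name}", original)  # keep a handle for debugging
    setattr(target, name, replacement)


def make_fp8_weight_loader(
    original_loader: Callable[..., Any],
    quantize: Callable[[Any], Any],
) -> Callable[..., Any]:
    """Wrap a weight loader so each tensor is quantized (e.g. to FP8) on the fly."""

    def load_weights(model: Any, weights: Any, **kwargs: Any) -> Any:
        quantized = ((name, quantize(tensor)) for name, tensor in weights)
        return original_loader(model, quantized, **kwargs)

    return load_weights


# Toy demonstration of the pattern on a stand-in module.
demo = types.SimpleNamespace(load_weights=lambda model, weights: dict(weights))
demo_original = demo.load_weights
patch_attr(demo, "load_weights", make_fp8_weight_loader(demo_original, lambda t: t))
print(demo.load_weights(None, [("w", 1.0)]))  # {'w': 1.0}, routed through the patched loader
```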

---

## Usage

We recommend configuring FP8 generation with the following settings:

```yaml
loss_fn:
# importance sampling helps improve stability
use_importance_sampling_correction: true

policy:
generation:
vllm_cfg:
precision: 'fp8'
# DeepGEMM is much more performant than vLLM's default CUTLASS FP8 sub-channel scaling kernels
use_deep_gemm: true
# Keeping the first and last few layers in bf16 reduces the multi-token error without
# a significant effect on performance
num_last_layers_in_bf16: 3
num_first_layers_in_bf16: 1
# Use FP32 scaling factors. Rounding scaling factors to the nearest power of 2 may improve
# quantization fidelity; however, this feature is still under research.
use_weight_pow2_scale: False
use_activation_pow2_scale: False
```

## Accuracy

On the Llama 8B recipe, we observe a ~5% accuracy loss with FP8 generation. Convergence is still under active research, and FP8 generation should be used with caution. We are investigating ways to close the accuracy gap and further improve performance.
1 change: 1 addition & 0 deletions docs/index.md
@@ -49,6 +49,7 @@ testing.md
documentation.md
debugging.md
nsys-profiling.md
fp8.md
guides/use-custom-vllm.md
apidocs/index.rst
```
3 changes: 3 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -174,6 +174,9 @@ policy:
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}
enforce_eager: False
use_deep_gemm: False
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
colocated:
# true: generation shares training GPUs
# false: uses dedicated generation resources
13 changes: 13 additions & 0 deletions examples/configs/grpo_math_8B_megatron_fp8.yaml
@@ -0,0 +1,13 @@
# GRPO Algorithm Configuration
defaults: "grpo_math_8B_megatron.yaml"

loss_fn:
use_importance_sampling_correction: true

policy:
generation:
vllm_cfg:
precision: 'fp8'
use_deep_gemm: true
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
@@ -0,0 +1,161 @@
grpo:
num_prompts_per_step: 64
num_generations_per_prompt: 32
max_rollout_turns: 1
max_num_steps: 500
normalize_rewards: true
use_leave_one_out_baseline: true
val_period: 10
val_at_start: false
max_val_samples: 256
val_batch_size: 256
seed: 42
loss_fn:
reference_policy_kl_penalty: 0.01
ratio_clip_min: 0.2
ratio_clip_max: 0.2
ratio_clip_c: null
use_on_policy_kl_approximation: false
use_importance_sampling_correction: True
token_level_loss: true
checkpointing:
enabled: true
checkpoint_dir: results/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
metric_name: val_reward
higher_is_better: true
keep_top_k: 3
save_period: 10
checkpoint_must_save_by: null
policy:
model_name: meta-llama/Llama-3.1-8B-Instruct
tokenizer:
name: meta-llama/Llama-3.1-8B-Instruct
train_global_batch_size: 512
train_micro_batch_size: 1
generation_batch_size: 32
logprob_batch_size: 2
max_total_sequence_length: 4096
precision: bfloat16
make_sequence_length_divisible_by: 1
max_grad_norm: 1

dtensor_cfg:
enabled: False

dynamic_batching:
enabled: False

sequence_packing:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

megatron_cfg:
enabled: True
empty_unused_memory_level: 1
converter_type: "LlamaForCausalLM"
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 2
context_parallel_size: 1
expert_tensor_parallel_size: 1
expert_model_parallel_size: 1
sequence_parallel: False
pipeline_dtype: ${policy.precision}
num_layers_in_first_pipeline_stage: null
num_layers_in_last_pipeline_stage: null
freeze_moe_router: True
moe_router_dtype: "fp64"
moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
apply_rope_fusion: True
activation_checkpointing: True
defer_fp32_logits: True

optimizer:
optimizer: "adam"
lr: 5.0e-7
min_lr: 5.0e-8
weight_decay: 0.0
bf16: True
fp16: False
params_dtype: "float32"

adam_beta1: 0.9
adam_beta2: 0.999
adam_eps: 1e-8

use_distributed_optimizer: True
use_precision_aware_optimizer: True

clip_grad: ${policy.max_grad_norm}

scheduler:
start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 2
lr_warmup_init: 5.0e-8

distributed_data_parallel_config:
grad_reduce_in_fp32: False
overlap_grad_reduce: True
overlap_param_gather: True
average_in_collective: True
use_custom_fsdp: False
data_parallel_sharding_strategy: "optim_grads_params"

generation:
backend: vllm
max_new_tokens: 4096
temperature: 1
top_p: 1
top_k: null
stop_token_ids:
- 128009
stop_strings: null
vllm_cfg:
async_engine: false
precision: 'fp8'
tensor_parallel_size: 1
pipeline_parallel_size: 1
gpu_memory_utilization: 0.6
max_model_len: 4096
enforce_eager: False
use_deep_gemm: true
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
colocated:
enabled: true
resources:
gpus_per_node: null
num_nodes: null
data:
max_input_seq_length: 4096
prompt_file: examples/prompts/cot.txt
system_prompt_file: null
dataset_name: OpenMathInstruct-2
shuffle: true
env:
math:
num_workers: 8
logger:
log_dir: logs/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
num_val_samples_to_print: 0
wandb_enabled: true
tensorboard_enabled: true
mlflow_enabled: false
monitor_gpus: true
wandb:
project: nemo-rl
name: grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
tensorboard: {}
gpu_monitoring:
collection_interval: 10
flush_interval: 10
cluster:
gpus_per_node: 8
num_nodes: 4
5 changes: 5 additions & 0 deletions nemo_rl/algorithms/grpo.py
@@ -319,6 +319,11 @@ def setup(
)
elif backend == "vllm":
generation_config = cast(VllmConfig, generation_config)
if generation_config["vllm_cfg"]["precision"] == "fp8":
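# FP8 rollouts can drift from the bf16 training policy, so importance
# sampling correction is required for good convergence (see docs/fp8.md).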
assert loss_config["use_importance_sampling_correction"] is True, (
"Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
)

policy_generation = VllmGeneration(
cluster=inference_cluster, config=generation_config
)