Skip to content
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
4c6b67b
kv-cache: prepare clean commit without excluded files
Nov 11, 2025
312d204
kv cache fp8 code refine and cleanup
Nov 11, 2025
e2cbca8
Fix indentation error. Enable using environment variables to set FP8…
Nov 11, 2025
c92dc55
Correct typos
Nov 11, 2025
ad09785
Update fp8.py
Nov 11, 2025
dca01e2
Remove dapo.py from response_datasets
Nov 11, 2025
03e47a9
Update sanity check in grpo.py. Remove redundant code in megatron bac…
Nov 11, 2025
96955de
Remove redundant comments
Nov 11, 2025
b1214cd
Remove _hook_builder
Nov 11, 2025
be8853d
rebase and update refitting process
zpqiu Nov 11, 2025
4ae3ec0
fix refitting bugs after rebase
zpqiu Nov 11, 2025
d66c262
Refactor FP8 KV cache scale handling by centralizing vLLM parameter n…
Nov 14, 2025
bc26b40
lint check
zpqiu Nov 14, 2025
7f709b4
Update to correct BF16 issue with load_weights
Nov 15, 2025
964d929
WIP: changes before rebase
Nov 16, 2025
ede399c
Code draft to support dynamic kv scales calculation
Nov 17, 2025
ff455d6
Refit should only take care of kv_scales when kv_cache_dtype is fp8 …
Nov 17, 2025
ac2b5ed
lint check
zpqiu Nov 17, 2025
9362035
remove debug prints
zpqiu Nov 17, 2025
f35662d
make refitting with kv scales cleaner
zpqiu Nov 17, 2025
667661e
remove debug print; raise errors of calibration process; refine refit…
zpqiu Nov 18, 2025
b84b10e
remove old hotfix about save_ckpt
zpqiu Nov 18, 2025
50c4abd
avoid importing vllm at grpo.py
zpqiu Nov 18, 2025
0dbf7ab
add placeholder func and parameter for dtensor path
zpqiu Nov 18, 2025
0ea586b
Refit should take care of kv_scales in the validation phase
sharonyu-115 Nov 21, 2025
8f759a1
Remove TODO comment
sharonyu-115 Nov 21, 2025
ea3e500
Merge branch 'main' into kv-cache-fp8
guyueh1 Nov 24, 2025
6f3bed7
Rename the example yaml file to grpo_math_qwen3_8B_fp8_kvcache.yaml a…
sharonyu-115 Nov 25, 2025
4089ab5
Add kv_cache fp8 test case to test_vllm_generation_with_megatron_trai…
sharonyu-115 Nov 25, 2025
ac6f66c
update pp>1 assert info
zpqiu Nov 26, 2025
af60c9a
update guard statements in DTensor path files
zpqiu Nov 26, 2025
f150419
Merge branch 'main' into kv-cache-fp8
zpqiu Nov 26, 2025
3a37119
add l1 test; update config yaml
zpqiu Nov 26, 2025
231a739
at first calibration align with training data processing to ensure pa…
zpqiu Nov 26, 2025
4f1324a
remove l1 test; upload missed recipe yaml
zpqiu Nov 27, 2025
47ea0c0
Merge branch 'main' into kv-cache-fp8
zpqiu Nov 27, 2025
94d16ec
resolve fp8 patch conflicts
zpqiu Nov 27, 2025
48b20aa
add nightly test
zpqiu Nov 27, 2025
603c366
increase gpu hours for new nightly test
zpqiu Nov 27, 2025
7ca82f3
allow a larger logprob tolerance
zpqiu Nov 27, 2025
b34ad76
update kv_cache_dtype with choices
zpqiu Dec 1, 2025
40fa1ac
add default kv_cache_dtype; update checking logic code
zpqiu Dec 2, 2025
6d65466
Merge branch 'main' into kv-cache-fp8
zpqiu Dec 2, 2025
db3ea88
add requires_kv_scale_sync property to GenerationInterface
zpqiu Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
defaults: ../../grpo_math_1B.yaml
grpo:
val_period: 20
checkpointing:
enabled: false
checkpoint_dir: results/grpo_qwen3_8b_fp8_kvcache
loss_fn:
use_importance_sampling_correction: true
policy:
model_name: Qwen/Qwen3-8B-Base
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 8192
dtensor_cfg:
enabled: false
optimizer: null
scheduler: null
megatron_cfg:
enabled: true
converter_type: Qwen3ForCausalLM
tensor_model_parallel_size: 4
optimizer:
lr: 1.0e-06
min_lr: 1.0e-06
weight_decay: 0.1
use_precision_aware_optimizer: false
scheduler:
lr_decay_iters: null
lr_warmup_iters: 10
lr_warmup_init: 1.0e-07
make_sequence_length_divisible_by: ${mul:${policy.megatron_cfg.tensor_model_parallel_size},
2}
generation:
vllm_cfg:
precision: fp8
kv_cache_dtype: fp8
use_deep_gemm: true
data:
max_input_seq_length: 2048
prompt_file: null
dataset_name: DAPOMath17K
env:
dapo:
num_workers: 16
math:
num_workers: 16
math_verify_impl: dapo_math_verify
cluster:
gpus_per_node: 8
107 changes: 103 additions & 4 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,24 @@ def init_vllm():
assert loss_config["use_importance_sampling_correction"] is True, (
"Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
)
if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8":
# FP8 KV cache requires FP8 model precision
assert generation_config["vllm_cfg"]["precision"] == "fp8", (
"kv_cache_dtype='fp8' requires precision='fp8'. "
"FP8 KV cache can only be used together with FP8 model weights."
)
# FP8 KV cache compatibility checks
assert policy_config["dtensor_cfg"]["enabled"] == False, (
"DTensor backend is not supported with kv cache fp8 enabled."
)
assert not _should_use_async_rollouts(master_config), (
"Async rollouts is not supported with kv cache fp8 enabled."
)
assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, (
"Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future."
)

## make vllm hf overrides match the training policy
generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
"hf_config_overrides", {}
)
Expand Down Expand Up @@ -870,12 +888,36 @@ def _should_use_penguin(master_config: MasterConfig) -> bool:
return should_use_penguin


# Function to check if KV cache scales should be calculated and synchronized during refit
def _should_sync_kv_scales(master_config: MasterConfig) -> bool:
Comment thread
terrykong marked this conversation as resolved.
Outdated
"""Check if KV cache scales should be synchronized during refit.

Returns True if kv_cache_dtype is fp8 (which requires precision=fp8).
KV scales are always computed and synced statically during training
when using FP8 KV cache.
"""
generation_config = master_config["policy"]["generation"]
if generation_config is None:
return False

backend = generation_config.get("backend", "")
if backend != "vllm":
return False

vllm_cfg = generation_config.get("vllm_cfg", {})
kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto")
Comment thread
terrykong marked this conversation as resolved.
Outdated

# Sync scales when using FP8 KV cache (always static in this design)
return kv_cache_dtype == "fp8"


def refit_policy_generation(
policy: ColocatablePolicyInterface,
policy_generation: GenerationInterface,
colocated_inference: bool,
_refit_buffer_size_gb: Optional[int] = None,
timer: Optional[Timer] = None,
kv_scales: Optional[dict[str, float]] = None,
) -> None:
"""Refit the policy generation interface with the latest policy weights.

Expand All @@ -886,6 +928,7 @@ def refit_policy_generation(
If it is None, the buffer size will be computed by the remaining memory.
This parameter is primarily used for testing.
timer: Optional Timer used to time the prepare/transfer/update phase
kv_scales: Optional dictionary of KV cache scales for FP8 quantization.
"""
if colocated_inference:
policy.offload_before_refit()
Expand Down Expand Up @@ -913,7 +956,7 @@ def refit_policy_generation(
)

futures_train = policy.stream_weights_via_ipc_zmq(
buffer_size_bytes=buffer_size_bytes
buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales
)
futures_inference = policy_generation.update_weights_via_ipc_zmq()
# wait for all futures to complete
Expand All @@ -922,7 +965,7 @@ def refit_policy_generation(
update_success = all(result for result in results if result is not None)
else:
# update weights through nccl
futures_train = policy.broadcast_weights_for_collective()
futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales)
futures_inference = policy_generation.update_weights_from_collective()
# wait for all futures to complete
ray.get(futures_train)
Expand Down Expand Up @@ -972,6 +1015,10 @@ def grpo_train(
)
timeout.start_iterations()

# Check if we need to sync KV cache scales (infer from config)
sync_kv_scales = _should_sync_kv_scales(master_config)
kv_scales_cache = None # Cache reused for computed kv scales

NEED_REFIT = True
# If policy_generation is None, use the policy as the generation interface (megatron framework backend)
if policy_generation is None:
Expand Down Expand Up @@ -1001,6 +1048,7 @@ def grpo_train(
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

# Run validation at the start if configured
# TODO: Add validation with kv scales if needed
Comment thread
terrykong marked this conversation as resolved.
if val_at_start and current_step == 0:
print("\n🔍 Running initial validation...", flush=True)
if NEED_REFIT and POLICY_GENERATION_STALE:
Expand Down Expand Up @@ -1062,8 +1110,43 @@ def grpo_train(
)
with timer.time("prepare_for_generation/total"):
if NEED_REFIT and POLICY_GENERATION_STALE:
# Compute KV scales if needed for FP8 quantization
Comment thread
terrykong marked this conversation as resolved.
if sync_kv_scales and kv_scales_cache is None:
print("▶ Computing KV cache scales...", flush=True)
policy.prepare_for_lp_inference()
# Align with training data processing to ensure parallel training compatibility
calib_flat, calib_input_lengths = (
batched_message_log_to_flat_message(
repeated_batch["message_log"],
pad_value_dict={
"token_ids": tokenizer.pad_token_id
},
make_sequence_length_divisible_by=master_config[
"policy"
]["make_sequence_length_divisible_by"],
)
)
# Create calibration data from flattened messages
calibration_data = BatchedDataDict[ClippedPGLossDataDict](
{
"input_ids": calib_flat["token_ids"],
"input_lengths": calib_input_lengths,
}
)
calibration_data.update(
calib_flat.get_multimodal_dict(as_tensors=False)
)
calibration_data.to("cpu")
kv_scales_cache = policy.calibrate_qkv_fp8_scales(
calibration_data, include_q=True
)["layers"]

refit_policy_generation(
policy, policy_generation, colocated_inference, timer=timer
policy,
policy_generation,
colocated_inference,
timer=timer,
kv_scales=kv_scales_cache if sync_kv_scales else None,
)
POLICY_GENERATION_STALE = False
else:
Expand Down Expand Up @@ -1266,6 +1349,19 @@ def grpo_train(
with timer.time("policy_training"):
train_results = policy.train(train_data, loss_fn)

# Recompute KV scales after policy training if needed
Comment thread
terrykong marked this conversation as resolved.
if sync_kv_scales:
with timer.time("recompute_kv_scales"):
print(
"▶ Recomputing KV cache scales after policy update...",
flush=True,
)
kv_scales_cache = policy.calibrate_qkv_fp8_scales(
train_data, include_q=True
Comment thread
terrykong marked this conversation as resolved.
)["layers"]
# Set generation as stale to force refit with new scales
POLICY_GENERATION_STALE = True

is_last_step = (total_steps + 1 >= max_num_steps) or (
(current_epoch + 1 == max_num_epochs)
and (current_step + 1 == len(dataloader))
Expand All @@ -1275,7 +1371,10 @@ def grpo_train(
if val_period > 0 and (total_steps + 1) % val_period == 0:
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy, policy_generation, colocated_inference
policy,
policy_generation,
colocated_inference,
kv_scales=kv_scales_cache if sync_kv_scales else None,
)
POLICY_GENERATION_STALE = False
else:
Expand Down
Loading
Loading