From 829da71faadb62d16a7ac057dae29f322f3df1b0 Mon Sep 17 00:00:00 2001 From: Zakir Jiwani <108548454+JiwaniZakir@users.noreply.github.com> Date: Thu, 2 Apr 2026 10:53:31 +0000 Subject: [PATCH] Fix norm_adv_by_std_in_grpo read from algorithm not stepwise_advantage Co-Authored-By: Claude Sonnet 4.6 --- rllm/experimental/common/config.py | 2 +- rllm/experimental/unified_trainer.py | 2 +- .../unified_trainer/test_algorithm_config.py | 59 +++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 tests/unified_trainer/test_algorithm_config.py diff --git a/rllm/experimental/common/config.py b/rllm/experimental/common/config.py index 5399da600..1e82fa08c 100644 --- a/rllm/experimental/common/config.py +++ b/rllm/experimental/common/config.py @@ -155,7 +155,7 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig": return cls( estimator=rLLMAdvantageEstimator(config.algorithm.adv_estimator), stepwise_advantage_mode=config.rllm.stepwise_advantage.mode, - norm_adv_by_std_in_grpo=config.rllm.stepwise_advantage.get("norm_adv_by_std_in_grpo", True), + norm_adv_by_std_in_grpo=config.rllm.algorithm.get("norm_adv_by_std_in_grpo", True), use_rllm=config.rllm.stepwise_advantage.get("use_rllm", False), use_precomputed_advantage=config.rllm.algorithm.get("use_precomputed_advantage", False), loss_fn=config.rllm.algorithm.get("loss_fn", None), diff --git a/rllm/experimental/unified_trainer.py b/rllm/experimental/unified_trainer.py index effda0473..ea3a4868b 100644 --- a/rllm/experimental/unified_trainer.py +++ b/rllm/experimental/unified_trainer.py @@ -254,7 +254,7 @@ def _validate_and_setup_configs(self): estimator=self.rllm_config.algorithm.adv_estimator, estimator_map=self.traj_group_adv_estimator_map, # TODO(listar2000): see if we can make this configurable in config as well stepwise_advantage_mode=self.rllm_config.stepwise_advantage.mode, - norm_adv_by_std_in_grpo=self.rllm_config.stepwise_advantage.get("norm_adv_by_std_in_grpo", True), + norm_adv_by_std_in_grpo=self.rllm_config.algorithm.get("norm_adv_by_std_in_grpo", True), use_rllm=self.rllm_config.algorithm.get("use_rllm", False), use_precomputed_advantage=self.rllm_config.algorithm.get("use_precomputed_advantage", False), loss_fn=self.rllm_config.algorithm.get("loss_fn", None), diff --git a/tests/unified_trainer/test_algorithm_config.py b/tests/unified_trainer/test_algorithm_config.py new file mode 100644 index 000000000..76446f901 --- /dev/null +++ b/tests/unified_trainer/test_algorithm_config.py @@ -0,0 +1,59 @@ +""" +Tests for AlgorithmConfig to verify norm_adv_by_std_in_grpo is read from +rllm.algorithm (not rllm.stepwise_advantage). + +See: https://github.com/rllm-org/rllm/issues/447 +""" + +import importlib.util +import os + +import pytest +from omegaconf import OmegaConf + +# Import config module directly to avoid heavy transitive deps (codetiming, verl, etc.) +_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../rllm/experimental/common/config.py") +_spec = importlib.util.spec_from_file_location("rllm_common_config", _CONFIG_PATH) +_mod = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +AlgorithmConfig = _mod.AlgorithmConfig + + +def _make_config(norm_adv_by_std_in_grpo: bool = True): + """Build a minimal DictConfig mirroring the real rllm config structure.""" + return OmegaConf.create( + { + "algorithm": { + "adv_estimator": "grpo", + }, + "rllm": { + "algorithm": { + "adv_estimator": "grpo", + "norm_adv_by_std_in_grpo": norm_adv_by_std_in_grpo, + "use_precomputed_advantage": False, + "loss_fn": None, + "lr_schedule": "constant", + "warmup_steps_ratio": 0.0, + }, + "stepwise_advantage": { + "mode": "broadcast", + # Intentionally omit norm_adv_by_std_in_grpo here to confirm + # the code reads from rllm.algorithm, not stepwise_advantage. + }, + }, + } + ) + + +def test_norm_adv_by_std_in_grpo_true_from_algorithm(): + """norm_adv_by_std_in_grpo=True is read from rllm.algorithm, not stepwise_advantage.""" + config = _make_config(norm_adv_by_std_in_grpo=True) + algo_config = AlgorithmConfig.from_config(config) + assert algo_config.norm_adv_by_std_in_grpo is True + + +def test_norm_adv_by_std_in_grpo_false_from_algorithm(): + """norm_adv_by_std_in_grpo=False is read from rllm.algorithm, not stepwise_advantage.""" + config = _make_config(norm_adv_by_std_in_grpo=False) + algo_config = AlgorithmConfig.from_config(config) + assert algo_config.norm_adv_by_std_in_grpo is False