Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rllm/experimental/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def from_config(cls, config: DictConfig) -> "AlgorithmConfig":
return cls(
estimator=rLLMAdvantageEstimator(config.algorithm.adv_estimator),
stepwise_advantage_mode=config.rllm.stepwise_advantage.mode,
norm_adv_by_std_in_grpo=config.rllm.stepwise_advantage.get("norm_adv_by_std_in_grpo", True),
norm_adv_by_std_in_grpo=config.rllm.algorithm.get("norm_adv_by_std_in_grpo", True),
use_rllm=config.rllm.stepwise_advantage.get("use_rllm", False),
use_precomputed_advantage=config.rllm.algorithm.get("use_precomputed_advantage", False),
loss_fn=config.rllm.algorithm.get("loss_fn", None),
Expand Down
2 changes: 1 addition & 1 deletion rllm/experimental/unified_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _validate_and_setup_configs(self):
estimator=self.rllm_config.algorithm.adv_estimator,
estimator_map=self.traj_group_adv_estimator_map, # TODO(listar2000): see if we can make this configurable in config as well
stepwise_advantage_mode=self.rllm_config.stepwise_advantage.mode,
norm_adv_by_std_in_grpo=self.rllm_config.stepwise_advantage.get("norm_adv_by_std_in_grpo", True),
norm_adv_by_std_in_grpo=self.rllm_config.algorithm.get("norm_adv_by_std_in_grpo", True),
use_rllm=self.rllm_config.algorithm.get("use_rllm", False),
use_precomputed_advantage=self.rllm_config.algorithm.get("use_precomputed_advantage", False),
loss_fn=self.rllm_config.algorithm.get("loss_fn", None),
Expand Down
59 changes: 59 additions & 0 deletions tests/unified_trainer/test_algorithm_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Tests for AlgorithmConfig to verify norm_adv_by_std_in_grpo is read from
rllm.algorithm (not rllm.stepwise_advantage).

See: https://github.com/rllm-org/rllm/issues/447
"""

import importlib.util
import os

import pytest
from omegaconf import OmegaConf

# Load the config module straight from its file path so the test does not
# pull in the heavy transitive dependencies (codetiming, verl, etc.) that
# importing the rllm package would trigger.
_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../../rllm/experimental/common/config.py")
_config_spec = importlib.util.spec_from_file_location("rllm_common_config", _CONFIG_PATH)
_config_module = importlib.util.module_from_spec(_config_spec)
_config_spec.loader.exec_module(_config_module)
# Re-export the class under test at module level for the tests below.
AlgorithmConfig = _config_module.AlgorithmConfig


def _make_config(norm_adv_by_std_in_grpo: bool = True):
    """Build a minimal DictConfig mirroring the real rllm config structure."""
    # The flag under test lives in rllm.algorithm.
    algorithm_section = {
        "adv_estimator": "grpo",
        "norm_adv_by_std_in_grpo": norm_adv_by_std_in_grpo,
        "use_precomputed_advantage": False,
        "loss_fn": None,
        "lr_schedule": "constant",
        "warmup_steps_ratio": 0.0,
    }
    # Intentionally omit norm_adv_by_std_in_grpo here to confirm the code
    # reads it from rllm.algorithm, not from stepwise_advantage.
    stepwise_section = {
        "mode": "broadcast",
    }
    return OmegaConf.create(
        {
            "algorithm": {
                "adv_estimator": "grpo",
            },
            "rllm": {
                "algorithm": algorithm_section,
                "stepwise_advantage": stepwise_section,
            },
        }
    )


def test_norm_adv_by_std_in_grpo_true_from_algorithm():
    """norm_adv_by_std_in_grpo=True is read from rllm.algorithm, not stepwise_advantage."""
    # Build the config with the flag set under rllm.algorithm only.
    result = AlgorithmConfig.from_config(_make_config(norm_adv_by_std_in_grpo=True))
    assert result.norm_adv_by_std_in_grpo is True


def test_norm_adv_by_std_in_grpo_false_from_algorithm():
    """norm_adv_by_std_in_grpo=False is read from rllm.algorithm, not stepwise_advantage."""
    # Build the config with the flag cleared under rllm.algorithm only.
    result = AlgorithmConfig.from_config(_make_config(norm_adv_by_std_in_grpo=False))
    assert result.norm_adv_by_std_in_grpo is False
Loading