3 changes: 2 additions & 1 deletion .gitignore
@@ -17,7 +17,8 @@ dist/
*.vscode/

# Test
.coverage
coverage.json
.coverage*
test_assets/

# Cache
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -2,6 +2,7 @@
grpo:
num_prompts_per_step: 32
num_generations_per_prompt: 16
max_rollout_turns: 1 # maximum turns for multi-turn rollouts; math environments have a single turn (answering the question)
max_num_steps: 1000000
normalize_rewards: true
use_leave_one_out_baseline: true
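For reference, a minimal sketch of how the new setting is consumed, mirroring the master_config["grpo"]["max_rollout_turns"] lookup in the grpo.py changes below (this is not the project's config loader, just an illustration):

import yaml

cfg_text = """
grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 16
  max_rollout_turns: 1  # math environments finish in a single turn
"""
master_config = yaml.safe_load(cfg_text)
max_rollout_turns = master_config["grpo"]["max_rollout_turns"]
print(max_rollout_turns)  # -> 1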
194 changes: 34 additions & 160 deletions nemo_reinforcer/algorithms/grpo.py
@@ -24,7 +24,10 @@
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt

from nemo_reinforcer.environments.interfaces import EnvironmentInterface
from nemo_reinforcer.environments.interfaces import (
EnvironmentInterface,
EnvironmentReturn,
)
from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
from nemo_reinforcer.data.interfaces import (
DatumSpec,
@@ -59,6 +62,7 @@
from nemo_reinforcer.utils.logger import Logger, LoggerConfig
from nemo_reinforcer.utils.timer import Timer
from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig
from nemo_reinforcer.experience.rollouts import run_multi_turn_rollout


# ===============================================================================
@@ -73,6 +77,7 @@ class GRPOConfig(TypedDict):
normalize_rewards: bool
use_leave_one_out_baseline: bool
val_period: int
val_batch_size: int
val_at_start: bool
checkpoint_dir: str

@@ -94,7 +99,7 @@ def _default_grpo_save_state() -> GRPOSaveState:
class MasterConfig(TypedDict):
policy: PolicyConfig
loss_fn: ClippedPGLossConfig
math_env: MathEnvConfig
env_configs: Dict[str, Any]
data: DataConfig
grpo: GRPOConfig
logger: LoggerConfig
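The single math_env: MathEnvConfig field becomes a mapping from task name to per-environment config. A sketch of what such a mapping might hold (the "math" key and its fields are illustrative assumptions, not values defined by this PR):

from typing import Any, Dict

# Hypothetical env_configs entry; grpo_train resolves each task name to an
# EnvironmentInterface via task_to_env, keyed the same way.
env_configs: Dict[str, Any] = {
    "math": {"num_workers": 8},  # field names here are assumptions for illustration
}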
@@ -283,120 +288,6 @@ def refit_policy_generation(
policy.offload_after_refit()


def generate_responses(
policy_generation: GenerationInterface,
generation_input_data: BatchedDataDict[GenerationDatumSpec],
batch: BatchedDataDict[DatumSpec],
tokenizer,
input_lengths: torch.Tensor,
include_logprobs: bool = True,
) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]:
"""Generate responses from policy."""
# Generate responses
generation_outputs = policy_generation.generate(generation_input_data)

# Extract generated tokens
generated_ids = []
unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
for output_ids, input_length, total_length in zip(
generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths
):
generated_ids.append(output_ids[input_length:total_length])

generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Append to message log
for i, (text, input_length, total_length) in enumerate(
zip(generated_texts, input_lengths, unpadded_sequence_lengths)
):
message = {
"role": "assistant",
"content": text,
"token_ids": generation_outputs["output_ids"][i, input_length:total_length],
}

if include_logprobs and "logprobs" in generation_outputs:
message["generation_logprobs"] = generation_outputs["logprobs"][
i, input_length:total_length
]

batch["message_log"][i].append(message)

metrics = {
"mean_generation_length": (
torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths)
).item()
/ len(unpadded_sequence_lengths),
"max_seqlen": torch.max(unpadded_sequence_lengths).item(),
}

return batch, generated_ids, metrics


def calculate_rewards(
batch: BatchedDataDict[DatumSpec],
task_to_env: Dict[str, EnvironmentInterface],
) -> Tuple[torch.Tensor, List[LLMMessageLogType]]:
"""Calculate rewards for generated responses.

Args:
batch: Batch containing message_log (LLMMessageLogType) with generated responses
task_to_env: Dictionary mapping task names to their corresponding environments

Returns:
rewards: Tensor of rewards
to_env: Simplified message logs sent to environment (LLMMessageLogType format)
"""
# Extract message logs for environment
to_env = [
get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
for i in range(len(batch["message_log"]))
]
task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))]

# Group messages by task type
task_groups = {}
for i, task_name in enumerate(task_names):
if task_name not in task_groups:
task_groups[task_name] = []
task_groups[task_name].append((i, to_env[i]))

# Calculate rewards for each task group concurrently
futures = []
future_to_indices = {} # Map future to its corresponding indices
for task_name, group in task_groups.items():
if task_name not in task_to_env:
raise ValueError(f"No environment found for task type: {task_name}")

# Extract indices and messages for this group
indices = [idx for idx, _ in group]
messages = [msg for _, msg in group]

# Get corresponding environment info
env_info = [batch["extra_env_info"][i] for i in indices]

# Submit task to environment and store future
future = task_to_env[task_name].step.remote(messages, env_info)
futures.append(future)
future_to_indices[future] = indices

results = ray.get(futures)
all_rewards = []
for future, result in zip(futures, results):
indices = future_to_indices[future]
_, _, task_rewards, _ = result

# Store results with their original indices
for idx, reward in zip(indices, task_rewards):
all_rewards.append((idx, reward))

# Sort results by original index to maintain order
all_rewards.sort(key=lambda x: x[0])
rewards = torch.tensor([reward for _, reward in all_rewards])

return rewards, to_env


# ===============================================================================
# Training & Validation
# ===============================================================================
@@ -463,7 +354,7 @@ def grpo_train(
print("▶ Preparing batch...")
with timer.time("data_processing"):
# Repeat batch items
repeated_batch = batch.repeat_interleave(
repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
master_config["grpo"]["num_generations_per_prompt"]
)
# Convert LLMMessageLogType to FlatMessagesType for generation
@@ -472,36 +363,33 @@
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
input_ids = batched_flat["token_ids"]
# Create generation-specific input structure
generation_input_data = BatchedDataDict[GenerationDatumSpec](
{
"input_ids": input_ids,
"input_lengths": input_lengths,
}
)

# Generate responses - this updates the LLMMessageLogType in repeated_batch
print(f"▶ Generating responses for batch of size {len(input_ids)}...")
print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()

with timer.time("generation"):
repeated_batch, _, gen_metrics = generate_responses(
policy_generation,
generation_input_data,
repeated_batch,
tokenizer,
input_lengths,
repeated_batch, rollout_metrics = run_multi_turn_rollout(
policy_generation=policy_generation,
input_batch=repeated_batch,
tokenizer=tokenizer,
task_to_env=task_to_env,
max_seq_len=master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
greedy=False,
)
policy_generation.finish_generation()

# Calculate rewards & advantages based on the updated LLMMessageLogType
print("▶ Calculating rewards...")
# Calculate rewards & advantages
print("▶ Processing rewards...")
with timer.time("reward_calculation"):
rewards, _ = calculate_rewards(repeated_batch, task_to_env)
# Extract rewards from final_batch
rewards = repeated_batch["total_reward"]

print("▶ Computing advantages...")
baseline, std = calculate_baseline_and_std_per_prompt(
@@ -665,14 +553,14 @@ def grpo_train(
metrics[k] = np.sum(v).item()
else:
metrics[k] = np.mean(v).item()
metrics.update(gen_metrics)
metrics.update(rollout_metrics)

timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print(f" • Loss: {metrics['loss']:.4f}")
print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
print(
f" • Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}"
f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
)

print("\n⏱️ Timing:")
@@ -726,39 +614,25 @@ def validate(
if batch_idx >= max_batches:
break

# Convert LLMMessageLogType to FlatMessagesType for generation
batched_flat, input_lengths = batched_message_log_to_flat_message(
val_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
# Extract input IDs
input_ids = batched_flat["token_ids"]
# Create generation-specific input structure
generation_input_data = BatchedDataDict(
{
"input_ids": input_ids,
"input_lengths": input_lengths,
}
)

# Generate responses (updates the LLMMessageLogType in batch_with_msg_logs)
val_batch, generated_ids, gen_metrics = generate_responses(
val_batch, gen_metrics = run_multi_turn_rollout(
policy_generation,
generation_input_data,
val_batch,
tokenizer,
input_lengths,
include_logprobs=False,
val_task_to_env,
max_seq_len=master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
greedy=False,
)

# Calculate rewards based on the updated LLMMessageLogType
with timer.time("reward_calculation"):
rewards, to_env = calculate_rewards(val_batch, val_task_to_env)
rewards = val_batch["total_reward"]

total_rewards.extend(rewards.tolist())
total_lengths.extend([len(ids) for ids in generated_ids])
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])

# Collect message logs for later display
to_env = get_keys_from_message_log(
val_batch["message_log"], ["role", "content"]
)
all_message_logs.extend(to_env)

# Calculate validation metrics
3 changes: 3 additions & 0 deletions nemo_reinforcer/algorithms/loss_functions.py
@@ -98,6 +98,9 @@ def __call__(

lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now)
mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()
if mult_prob_error == 0.0:
# mult_prob_error is 0.0 when every token is masked/invalid; fall back to 1.0 (a neutral value) so it does not skew the logged stats
mult_prob_error = 1.0

next_token_logits = next_token_logits.to(torch.float32)

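Since exp(|generation_logprobs - prev_logprobs|) >= 1 for any real difference, the masked mean can only come out as 0.0 when no tokens contribute to it, which is why 1.0 acts as a neutral "no error" value. A small self-contained illustration (the masked_mean below is a stand-in for the project's helper, assuming it returns 0.0 over an empty mask):

import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # stand-in: mean over unmasked entries, 0.0 when the mask is empty
    return (values * mask).sum() / mask.sum().clamp(min=1)

lp_error = torch.abs(torch.tensor([0.1, -0.2, 0.0]))
mask = torch.zeros(3)  # everything masked out
mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()  # -> 0.0
if mult_prob_error == 0.0:
    mult_prob_error = 1.0  # exp(|x|) >= 1, so 0.0 can only mean "no valid tokens"
print(mult_prob_error)  # -> 1.0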
4 changes: 4 additions & 0 deletions nemo_reinforcer/data/datasets.py
@@ -124,6 +124,9 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
idx = [datum_spec["idx"] for datum_spec in data_batch]
batch_max_length = torch.ones_like(length) * length.max()

# Extract stop_strings if present
stop_strings = [datum.get("stop_strings", None) for datum in data_batch]

output = BatchedDataDict(
message_log=message_log,
length=length,
@@ -132,6 +135,7 @@
task_name=task_names,
idx=idx,
batch_max_length=batch_max_length,
stop_strings=stop_strings,
)
return output

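A minimal illustration of the new field flowing through collation: datums that don't define stop_strings simply contribute None, so the batch-level list always lines up with the other per-datum fields (the example values are made up):

# illustrative datums; only the fields relevant to stop_strings are shown
data_batch = [
    {"idx": 0, "task_name": "math", "stop_strings": ["</answer>"]},
    {"idx": 1, "task_name": "math"},  # no stop strings for this datum
]
stop_strings = [datum.get("stop_strings", None) for datum in data_batch]
print(stop_strings)  # -> [['</answer>'], None]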
1 change: 1 addition & 0 deletions nemo_reinforcer/data/interfaces.py
@@ -32,6 +32,7 @@ class DatumSpec(TypedDict):
loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
idx: int
task_name: Optional[str] = "default"
stop_strings: Optional[List[str]] = None # Optional stop strings for generation
__extra__: Any # This allows additional fields of any type


7 changes: 5 additions & 2 deletions nemo_reinforcer/data/llm_message_utils.py
@@ -289,8 +289,11 @@ def batched_message_log_to_flat_message(
# Create input_lengths tensor
input_lengths = []
for seq in sequenced_lists:
seq_len = next(
(v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0
# Find the maximum length among all tensors in the dictionary, default to 0 if none exist
# Use maximum here since there may be keys that aren't populated for all messages yet.
# For example, logprobs don't get populated for non-generated tokens until post-processing.
seq_len = max(
(v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0
)
input_lengths.append(seq_len)
input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32)
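The practical difference is "length of whichever tensor happens to come first in the dict" versus "length of the longest tensor". A short illustration with a sequence whose logprobs are only partially populated (key names mirror the message fields used elsewhere in this PR; the dict order is arranged to show the old failure mode):

import torch

seq = {
    "generation_logprobs": torch.zeros(4),           # only generated tokens populated so far
    "token_ids": torch.zeros(10, dtype=torch.long),  # full sequence
}

old_len = next((v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0)         # -> 4 (under-counts)
new_len = max((v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), default=0)  # -> 10 (full length)
print(old_len, new_len)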