Skip to content
63 changes: 60 additions & 3 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import asyncio
import json
from asyncio import Future, Semaphore
from collections import Counter
from collections import Counter, defaultdict
from contextlib import nullcontext
from itertools import chain, repeat
from pathlib import Path
Expand Down Expand Up @@ -75,6 +75,14 @@ class RolloutCollectionConfig(BaseNeMoGymCLIConfig):
default_factory=dict,
description="Overrides for the responses_create_params e.g. temperature, max_output_tokens, etc.",
)
output_profiled_jsonl_fpath: Optional[str] = Field(
default=None,
description="Output file for reward profiled dataset (requires num_repeats > 1).",
)
pass_threshold: Optional[float] = Field(
default=None,
description="Reward threshold for pass_rate calculation. If None, pass_rate not computed.",
)


class RolloutCollectionHelper(BaseModel): # pragma: no cover
Expand All @@ -90,7 +98,14 @@ async def run_from_config(self, config: RolloutCollectionConfig):

if config.num_repeats:
previous_length = len(rows)
rows = list(chain.from_iterable(repeat(row, config.num_repeats) for row in rows))
if config.output_profiled_jsonl_fpath:
expanded = []
for prompt_idx, row in enumerate(rows):
for _ in range(config.num_repeats):
expanded.append({**row, "_prompt_index": prompt_idx})
rows = expanded
else:
rows = list(chain.from_iterable(repeat(row, config.num_repeats) for row in rows))
print(f"Repeating rows (in a pattern of abc to aabbcc) from {previous_length} to {len(rows)}!")

semaphore = nullcontext()
Expand All @@ -117,6 +132,7 @@ async def run_from_config(self, config: RolloutCollectionConfig):
)

metrics = Counter()
results = [] if config.output_profiled_jsonl_fpath else None
Path(config.output_jsonl_fpath).parent.mkdir(exist_ok=True, parents=True)
with open(config.output_jsonl_fpath, "a") as f:

Expand All @@ -128,15 +144,56 @@ async def _post_coroutine(row: dict) -> None:
response = await server_client.post(server_name=agent_name, url_path="/run", json=row)
await raise_for_status(response)
result = await get_response_json(response)
if config.output_profiled_jsonl_fpath:
result["_prompt_index"] = row.get("_prompt_index")
result["_original_row"] = {k: v for k, v in row.items() if not k.startswith("_")}
results.append(result)
f.write(json.dumps(result) + "\n")
metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))})
metrics.update(
{k: v for k, v in result.items() if isinstance(v, (int, float)) and not k.startswith("_")}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we avoid overfitting to this unwanted _task_index in the metrics and just update metrics first, add task_index, and then dump?

)

await tqdm.gather(*map(_post_coroutine, rows), desc="Collecting rollouts", miniters=tqdm_miniters)

avg_metrics = {k: v / len(rows) for k, v in metrics.items()}
avg_metrics.setdefault("reward", 0.0)
print(json.dumps(avg_metrics, indent=4))

if config.output_profiled_jsonl_fpath:
if not config.num_repeats or config.num_repeats < 2:
print("Warning: output_profiled_jsonl_fpath requires num_repeats >= 2. Skipping profiling.")
else:
grouped = defaultdict(list)
for result in results:
prompt_idx = result.get("_prompt_index", 0)
grouped[prompt_idx].append(result)

Path(config.output_profiled_jsonl_fpath).parent.mkdir(exist_ok=True, parents=True)
with open(config.output_profiled_jsonl_fpath, "w") as profiled_tasks:
for prompt_idx in sorted(grouped.keys()):
task_rollouts = grouped[prompt_idx]
rewards = [r.get("reward", 0.0) for r in task_rollouts]

original_row = task_rollouts[0].get("_original_row", {})
profiled_task = {**original_row}

profiled_task["avg_reward"] = sum(rewards) / len(rewards)
profiled_task["std_reward"] = (
sum((r - profiled_task["avg_reward"]) ** 2 for r in rewards) / len(rewards)
) ** 0.5
profiled_task["min_reward"] = min(rewards)
profiled_task["max_reward"] = max(rewards)
profiled_task["total_samples"] = len(rewards)

if config.pass_threshold is not None:
passed = sum(1 for r in rewards if r >= config.pass_threshold)
profiled_task["pass_rate"] = passed / len(rewards)
profiled_task["pass_rate_total"] = len(rewards)
profiled_task["pass_rate_passed"] = passed
profiled_task["pass_threshold"] = config.pass_threshold

profiled_tasks.write(json.dumps(profiled_task) + "\n")

def run_examples(
self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None
) -> Iterator[Future]:
Expand Down