Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class RolloutCollectionConfig(BaseNeMoGymCLIConfig):
```
"""

agent_name: str = Field(description="The agent to collect rollouts from.")
agent_name: Optional[str] = Field(
default=None,
description="The agent to collect rollouts from. If not specified, uses agent_ref from each data row.",
)
input_jsonl_fpath: str = Field(
description="The input data source to use to collect rollouts, in the form of a file path to a jsonl file."
)
Expand Down Expand Up @@ -103,13 +106,23 @@ async def run_from_config(self, config: RolloutCollectionConfig):
if config.responses_create_params:
print(f"Overriding responses_create_params fields with {config.responses_create_params}")

# Validate all rows have an agent specified (either via config or agent_ref in data)
if not config.agent_name:
missing_agent_indices = [idx for idx, row in enumerate(rows) if not row.get("agent_ref", {}).get("name")]
if missing_agent_indices:
raise ValueError(
f"No agent specified for rows {missing_agent_indices}. Either provide +agent_name config or include agent_ref in data."
)

metrics = Counter()
with open(config.output_jsonl_fpath, "a") as f:

async def _post_coroutine(row: dict) -> None:
row["responses_create_params"] = row["responses_create_params"] | config.responses_create_params
# Use config.agent_name if specified, otherwise use agent_ref from the row
agent_name = config.agent_name or row.get("agent_ref", {}).get("name")
async with semaphore:
response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
response = await server_client.post(server_name=agent_name, url_path="/run", json=row)
await raise_for_status(response)
result = await response.json()
f.write(json.dumps(result) + "\n")
Expand Down