diff --git a/nemo_gym/rollout_collection.py b/nemo_gym/rollout_collection.py
index f522d2b07..d03b8b237 100644
--- a/nemo_gym/rollout_collection.py
+++ b/nemo_gym/rollout_collection.py
@@ -51,7 +51,10 @@ class RolloutCollectionConfig(BaseNeMoGymCLIConfig):
     ```
     """
 
-    agent_name: str = Field(description="The agent to collect rollouts from.")
+    agent_name: Optional[str] = Field(
+        default=None,
+        description="The agent to collect rollouts from. If not specified, uses agent_ref from each data row.",
+    )
     input_jsonl_fpath: str = Field(
         description="The input data source to use to collect rollouts, in the form of a file path to a jsonl file."
     )
@@ -103,13 +106,27 @@ async def run_from_config(self, config: RolloutCollectionConfig):
         if config.responses_create_params:
             print(f"Overriding responses_create_params fields with {config.responses_create_params}")
 
+        # Validate all rows have an agent specified (either via config or agent_ref in data).
+        # `agent_ref` may be absent or None, so fall back to {} before reading "name".
+        if not config.agent_name:
+            missing_agent_indices = [
+                idx for idx, row in enumerate(rows) if not (row.get("agent_ref") or {}).get("name")
+            ]
+            if missing_agent_indices:
+                raise ValueError(
+                    f"No agent specified for rows {missing_agent_indices}. Either provide +agent_name config or include agent_ref in data."
+                )
+
         metrics = Counter()
         with open(config.output_jsonl_fpath, "a") as f:
 
             async def _post_coroutine(row: dict) -> None:
                 row["responses_create_params"] = row["responses_create_params"] | config.responses_create_params
+                # Use config.agent_name if specified, otherwise use agent_ref from the row
+                # (a None agent_ref is treated the same as a missing one).
+                agent_name = config.agent_name or (row.get("agent_ref") or {}).get("name")
                 async with semaphore:
-                    response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
+                    response = await server_client.post(server_name=agent_name, url_path="/run", json=row)
                     await raise_for_status(response)
                     result = await response.json()
                     f.write(json.dumps(result) + "\n")