diff --git a/resources_servers/aviary/schemas.py b/resources_servers/aviary/schemas.py index a192ad7d6..ca504969c 100644 --- a/resources_servers/aviary/schemas.py +++ b/resources_servers/aviary/schemas.py @@ -70,8 +70,8 @@ class AviaryStepResponse(BaseModel): class AviaryNeMoGymResponse(NeMoGymResponse): env_id: str group_id: str - contains_transitions: Literal[True] = True - output: list[list[NeMoGymResponseOutputItem]] + contains_transitions: bool + output: list[NeMoGymResponseOutputItem] | list[list[NeMoGymResponseOutputItem]] class AviaryAgentVerifyRequest(BaseVerifyRequest): diff --git a/responses_api_agents/aviary_agent/app.py b/responses_api_agents/aviary_agent/app.py index f515de08e..b04192a09 100644 --- a/responses_api_agents/aviary_agent/app.py +++ b/responses_api_agents/aviary_agent/app.py @@ -30,6 +30,7 @@ NeMoGymResponseCreateParamsNonStreaming, NeMoGymResponseFunctionToolCall, NeMoGymResponseInput, + NeMoGymResponseOutputItem, NeMoGymResponseOutputMessage, ) from resources_servers.aviary.schemas import ( @@ -54,6 +55,11 @@ class AviaryAgentConfig(BaseResponsesAPIAgentConfig): description="The maximum number of steps to take in the environment. " "If not set, the agent will run indefinitely.", ) + return_transitions: bool = Field( + default=True, + description="If True, return a list of transitions, instead of the " + "whole trajectory as a single NeMoGymResponseOutputItem.", + ) # Doesn't cause an issue if not set, but if it is, then # we can avoid sending requests that are guaranteed to @@ -146,6 +152,7 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse: env_id = seed_session_response.env_id model_response: NeMoGymResponse | None = None agent_state_history: list[NeMoGymResponseInput] = [] + all_messages: list[NeMoGymResponseOutputItem] = [] model_server_cookies = None step = 0 @@ -217,7 +224,12 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse: done = env_response.done agent_state = self.update_agent_state(agent_state, model_output, obs, successful_transition) - agent_state_history.append(cast(NeMoGymResponseInput, agent_state.input)) + if self.config.return_transitions: + agent_state_history.append(cast(NeMoGymResponseInput, agent_state.input)) + else: + all_messages.extend(model_output) + if successful_transition: + all_messages.extend(obs) if done: break @@ -231,10 +243,13 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse: "Rollout crashed or terminated before first transition completed, cannot proceed." ) - output = AviaryNeMoGymResponse.model_validate( - model_response.model_dump() - | {"output": agent_state_history, "env_id": env_id, "group_id": str(req.task_idx)} - ) + output_overrides = { + "env_id": env_id, + "group_id": str(req.task_idx), + "contains_transitions": self.config.return_transitions, + "output": agent_state_history if self.config.return_transitions else all_messages, + } + output = AviaryNeMoGymResponse.model_validate(model_response.model_dump() | output_overrides) return output async def run(self, body: AviaryAgentRunRequest) -> AviaryAgentVerifyResponse: diff --git a/responses_api_agents/aviary_agent/tests/test_app.py b/responses_api_agents/aviary_agent/tests/test_app.py index a76af567b..28f35dbd5 100644 --- a/responses_api_agents/aviary_agent/tests/test_app.py +++ b/responses_api_agents/aviary_agent/tests/test_app.py @@ -308,6 +308,100 @@ async def test_responses_multi_step(self) -> None: assert calls[4][1]["json"]["action"][0]["call_id"] == "call_2" assert calls[5] == call(server_name="my resources name", url_path="/close", json={"env_id": env_id}) + async def test_responses_return_transitions_false(self) -> None: + config = AviaryAgentConfig( + host="0.0.0.0", + port=8080, + entrypoint="", + name="", + model_server=ModelServerRef( + type="responses_api_models", + name="my model name", + ), + resources_server=ResourcesServerRef( + type="resources_servers", + name="my resources name", + ), + return_transitions=False, + ) + agent = AviaryAgent(config=config, server_client=MagicMock(spec=ServerClient)) + + env_id = str(uuid.uuid4()) + mock_seed_session_data = {"env_id": env_id, "obs": [{"role": "user", "content": "Step 0"}], "tools": []} + + mock_response_1 = { + "id": "resp_1", + "created_at": 1753983920.0, + "model": "dummy_model", + "object": "response", + "output": [ + NeMoGymResponseFunctionToolCall( + call_id="call_1", name="tool_1", arguments=json.dumps({"arg": "val1"}) + ).model_dump() + ], + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + } + mock_step_1 = { + "obs": [{"type": "function_call_output", "call_id": "call_1", "output": "Result 1"}], + "reward": 0.0, + "done": False, + } + + mock_response_2 = { + "id": "resp_2", + "created_at": 1753983921.0, + "model": "dummy_model", + "object": "response", + "output": [ + NeMoGymResponseFunctionToolCall( + call_id="call_2", name="tool_2", arguments=json.dumps({"arg": "val2"}) + ).model_dump() + ], + "parallel_tool_calls": True, + "tool_choice": "auto", + "tools": [], + } + mock_step_2 = { + "obs": [{"type": "function_call_output", "call_id": "call_2", "output": "Result 2"}], + "reward": 1.0, + "done": True, + } + + mock_close_data = {"message": "Success", "success": True} + + dotjson_mock = AsyncMock() + dotjson_mock.json.side_effect = [ + mock_seed_session_data, + mock_response_1, + mock_step_1, + mock_response_2, + mock_step_2, + mock_close_data, + ] + dotjson_mock.raise_for_status = MagicMock() + dotjson_mock.cookies = MagicMock() + agent.server_client.post = AsyncMock(return_value=dotjson_mock) + + request = AviaryAgentRunRequest( + task_idx=42, responses_create_params=NeMoGymResponseCreateParamsNonStreaming(input=[]) + ) + response = await agent.responses(request) + + assert response.env_id == env_id + assert response.group_id == "42" + assert response.contains_transitions is False + assert len(response.output) == 4 + assert response.output[0].type == "function_call" + assert response.output[0].call_id == "call_1" + assert response.output[1].type == "function_call_output" + assert response.output[1].call_id == "call_1" + assert response.output[2].type == "function_call" + assert response.output[2].call_id == "call_2" + assert response.output[3].type == "function_call_output" + assert response.output[3].call_id == "call_2" + async def test_responses_collapse_old_env_states(self) -> None: config = AviaryAgentConfig( host="0.0.0.0",