Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions resources_servers/aviary/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ class AviaryStepResponse(BaseModel):
class AviaryNeMoGymResponse(NeMoGymResponse):
env_id: str
group_id: str
contains_transitions: Literal[True] = True
output: list[list[NeMoGymResponseOutputItem]]
contains_transitions: bool
output: list[NeMoGymResponseOutputItem] | list[list[NeMoGymResponseOutputItem]]


class AviaryAgentVerifyRequest(BaseVerifyRequest):
Expand Down
25 changes: 20 additions & 5 deletions responses_api_agents/aviary_agent/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
NeMoGymResponseCreateParamsNonStreaming,
NeMoGymResponseFunctionToolCall,
NeMoGymResponseInput,
NeMoGymResponseOutputItem,
NeMoGymResponseOutputMessage,
)
from resources_servers.aviary.schemas import (
Expand All @@ -54,6 +55,11 @@ class AviaryAgentConfig(BaseResponsesAPIAgentConfig):
description="The maximum number of steps to take in the environment. "
"If not set, the agent will run indefinitely.",
)
return_transitions: bool = Field(
default=True,
description="If True, return a list of transitions, instead of the "
"whole trajectory as a single NeMoGymResponseOutputItem.",
)

# Doesn't cause an issue if not set, but if it is, then
# we can avoid sending requests that are guaranteed to
Expand Down Expand Up @@ -146,6 +152,7 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse:
env_id = seed_session_response.env_id
model_response: NeMoGymResponse | None = None
agent_state_history: list[NeMoGymResponseInput] = []
all_messages: list[NeMoGymResponseOutputItem] = []
model_server_cookies = None

step = 0
Expand Down Expand Up @@ -217,7 +224,12 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse:
done = env_response.done

agent_state = self.update_agent_state(agent_state, model_output, obs, successful_transition)
agent_state_history.append(cast(NeMoGymResponseInput, agent_state.input))
if self.config.return_transitions:
agent_state_history.append(cast(NeMoGymResponseInput, agent_state.input))
else:
all_messages.extend(model_output)
if successful_transition:
all_messages.extend(obs)

if done:
break
Expand All @@ -231,10 +243,13 @@ async def responses(self, req: AviaryAgentRunRequest) -> AviaryNeMoGymResponse:
"Rollout crashed or terminated before first transition completed, cannot proceed."
)

output = AviaryNeMoGymResponse.model_validate(
model_response.model_dump()
| {"output": agent_state_history, "env_id": env_id, "group_id": str(req.task_idx)}
)
output_overrides = {
"env_id": env_id,
"group_id": str(req.task_idx),
"contains_transitions": self.config.return_transitions,
"output": agent_state_history if self.config.return_transitions else all_messages,
}
output = AviaryNeMoGymResponse.model_validate(model_response.model_dump() | output_overrides)
return output

async def run(self, body: AviaryAgentRunRequest) -> AviaryAgentVerifyResponse:
Expand Down
94 changes: 94 additions & 0 deletions responses_api_agents/aviary_agent/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,100 @@ async def test_responses_multi_step(self) -> None:
assert calls[4][1]["json"]["action"][0]["call_id"] == "call_2"
assert calls[5] == call(server_name="my resources name", url_path="/close", json={"env_id": env_id})

async def test_responses_return_transitions_false(self) -> None:
config = AviaryAgentConfig(
host="0.0.0.0",
port=8080,
entrypoint="",
name="",
model_server=ModelServerRef(
type="responses_api_models",
name="my model name",
),
resources_server=ResourcesServerRef(
type="resources_servers",
name="my resources name",
),
return_transitions=False,
)
agent = AviaryAgent(config=config, server_client=MagicMock(spec=ServerClient))

env_id = str(uuid.uuid4())
mock_seed_session_data = {"env_id": env_id, "obs": [{"role": "user", "content": "Step 0"}], "tools": []}

mock_response_1 = {
"id": "resp_1",
"created_at": 1753983920.0,
"model": "dummy_model",
"object": "response",
"output": [
NeMoGymResponseFunctionToolCall(
call_id="call_1", name="tool_1", arguments=json.dumps({"arg": "val1"})
).model_dump()
],
"parallel_tool_calls": True,
"tool_choice": "auto",
"tools": [],
}
mock_step_1 = {
"obs": [{"type": "function_call_output", "call_id": "call_1", "output": "Result 1"}],
"reward": 0.0,
"done": False,
}

mock_response_2 = {
"id": "resp_2",
"created_at": 1753983921.0,
"model": "dummy_model",
"object": "response",
"output": [
NeMoGymResponseFunctionToolCall(
call_id="call_2", name="tool_2", arguments=json.dumps({"arg": "val2"})
).model_dump()
],
"parallel_tool_calls": True,
"tool_choice": "auto",
"tools": [],
}
mock_step_2 = {
"obs": [{"type": "function_call_output", "call_id": "call_2", "output": "Result 2"}],
"reward": 1.0,
"done": True,
}

mock_close_data = {"message": "Success", "success": True}

dotjson_mock = AsyncMock()
dotjson_mock.json.side_effect = [
mock_seed_session_data,
mock_response_1,
mock_step_1,
mock_response_2,
mock_step_2,
mock_close_data,
]
dotjson_mock.raise_for_status = MagicMock()
dotjson_mock.cookies = MagicMock()
agent.server_client.post = AsyncMock(return_value=dotjson_mock)

request = AviaryAgentRunRequest(
task_idx=42, responses_create_params=NeMoGymResponseCreateParamsNonStreaming(input=[])
)
response = await agent.responses(request)

assert response.env_id == env_id
assert response.group_id == "42"
assert response.contains_transitions is False
assert len(response.output) == 4
assert response.output[0].type == "function_call"
assert response.output[0].call_id == "call_1"
assert response.output[1].type == "function_call_output"
assert response.output[1].call_id == "call_1"
assert response.output[2].type == "function_call"
assert response.output[2].call_id == "call_2"
assert response.output[3].type == "function_call_output"
assert response.output[3].call_id == "call_2"

async def test_responses_collapse_old_env_states(self) -> None:
config = AviaryAgentConfig(
host="0.0.0.0",
Expand Down