Merged
49 commits
8b98a63  aviary integration  (sidnarayanan, Sep 16, 2025)
6e05dd5  do not exit if there are no tool calls  (sidnarayanan, Sep 16, 2025)
0752f9a  notebook/bixbench env  (sidnarayanan, Sep 17, 2025)
88b759c  add tests; close endpoint  (sidnarayanan, Sep 17, 2025)
1e540da  aviary agent tests  (sidnarayanan, Sep 17, 2025)
7e91d34  make sure close is called  (sidnarayanan, Sep 17, 2025)
4b3b401  skip nb test  (sidnarayanan, Sep 17, 2025)
82b545f  Fix OpenAI ResponseReasoningItem.status property (#54)  (bxyu-nvidia, Sep 17, 2025)
710da53  add copyright text  (sidnarayanan, Sep 17, 2025)
a902b4e  Merge branch 'main' of github.com:NVIDIA-NeMo/Gym into aviary  (sidnarayanan, Sep 17, 2025)
d2bdbf7  make imports absolute  (sidnarayanan, Sep 17, 2025)
58c4fae  improve submit_solution docstring  (sidnarayanan, Sep 18, 2025)
d358959  Merge branch 'main' of github.com:NVIDIA-NeMo/Gym into aviary-rl  (sidnarayanan, Sep 18, 2025)
2e56b0a  bring back datasets field  (sidnarayanan, Sep 18, 2025)
242433d  impl  (bxyu-nvidia, Sep 19, 2025)
526bac6  Merge branch 'main' of github.com:NVIDIA-NeMo/Gym into aviary  (sidnarayanan, Sep 19, 2025)
67f13e4  Merge remote-tracking branch 'remotes/origin/bxyu/fix-vllmmodel-white…  (sidnarayanan, Sep 19, 2025)
ecbb11c  updating data files to contain request params  (sidnarayanan, Sep 19, 2025)
ccc3037  rename notebook->bixbench  (sidnarayanan, Sep 23, 2025)
6648449  remote env client app  (sidnarayanan, Sep 24, 2025)
686aee0  explicitly pass args  (sidnarayanan, Sep 24, 2025)
b2ebd59  allow server url/api key to fall back to env vars  (sidnarayanan, Sep 25, 2025)
43ecf42  Training (#3)  (sidnarayanan, Oct 29, 2025)
6f1d5c9  Merge branch 'main' into aviary  (sidnarayanan, Dec 15, 2025)
524ab77  rm workbench csv files  (sidnarayanan, Dec 15, 2025)
5c8b5e6  checkpoint pre-commit changes  (sidnarayanan, Dec 15, 2025)
6403ce7  VLLMTokenizeResponse inherits from BaseModel  (sidnarayanan, Dec 15, 2025)
1342bf8  client test  (sidnarayanan, Dec 15, 2025)
33d149c  fix empty tool calls  (sidnarayanan, Dec 15, 2025)
217c95d  move nemo_gym.integrations.aviary to resources_servers.aviary.data_mo…  (sidnarayanan, Dec 15, 2025)
693371b  add copyright to client_app.py  (sidnarayanan, Dec 15, 2025)
43b3e06  AviaryAgent defaults to ending rollout if no tool calls  (sidnarayanan, Dec 15, 2025)
00b9413  define example files required by tests  (sidnarayanan, Dec 15, 2025)
bbc6909  update README, fix bixbench_aviary.yaml entrypoint  (sidnarayanan, Dec 15, 2025)
05bf4df  extend AviaryAgent tests; add README  (sidnarayanan, Dec 15, 2025)
c2de7fd  sync with main: test_train_data_utils.py and example_session_state_mg…  (sidnarayanan, Dec 15, 2025)
d6c29ad  add HotPotQA environment to aviary integration (#527)  (cmunley1, Dec 17, 2025)
8bcd6e0  rm train/valiation gms8k jsonls  (sidnarayanan, Dec 18, 2025)
d487c13  rm total_len; do try: agent rollout, finally: env close  (sidnarayanan, Dec 18, 2025)
aca7d17  break instead of raise if NeMoGymResponse fails validation  (sidnarayanan, Dec 18, 2025)
1241ad2  improve README and AviaryAgentConfig documentation  (sidnarayanan, Dec 18, 2025)
08ffad7  bixbench_train.jsonl -> bixbench_example.jsonl  (sidnarayanan, Dec 18, 2025)
f24317a  Merge branch 'main' of github.com:NVIDIA-NeMo/Gym into aviary  (sidnarayanan, Dec 25, 2025)
7b3051f  update README  (sidnarayanan, Dec 25, 2025)
01c8507  remove tokenize endpoint  (sidnarayanan, Dec 25, 2025)
2e61537  check if raw_model_response has error status  (sidnarayanan, Dec 25, 2025)
c539562  only print request info if _GLOBAL_AIOHTTP_CLIENT_REQUEST_DEBUG  (sidnarayanan, Jan 7, 2026)
0d2f0ab  suppress openai_utils._raise_for_status too  (sidnarayanan, Jan 7, 2026)
2371ced  Merge branch 'main' of github.com:NVIDIA-NeMo/Gym into aviary  (sidnarayanan, Jan 7, 2026)
10 changes: 8 additions & 2 deletions nemo_gym/openai_utils.py
@@ -78,7 +78,13 @@
 from pydantic import BaseModel, ConfigDict, Field
 from typing_extensions import TypedDict
 
-from nemo_gym.server_utils import MAX_NUM_TRIES, ClientResponse, raise_for_status, request
+from nemo_gym.server_utils import (
+    _GLOBAL_AIOHTTP_CLIENT_REQUEST_DEBUG,
+    MAX_NUM_TRIES,
+    ClientResponse,
+    raise_for_status,
+    request,
+)
 
 
 ########################################
@@ -466,7 +472,7 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse:
         response.raise_for_status()
 
     async def _raise_for_status(self, response: ClientResponse, request_kwargs: Dict[str, Any]) -> None:
-        if not response.ok:
+        if not response.ok and _GLOBAL_AIOHTTP_CLIENT_REQUEST_DEBUG:
             print(f"Request kwargs: {json.dumps(request_kwargs)}")
 
         await raise_for_status(response)
3 changes: 2 additions & 1 deletion nemo_gym/server_utils.py
@@ -171,7 +171,8 @@ async def request(
 async def raise_for_status(response: ClientResponse) -> None:  # pragma: no cover
     if not response.ok:
         content = await response.content.read()
-        print(f"""Request info: {response.request_info}
+        if _GLOBAL_AIOHTTP_CLIENT_REQUEST_DEBUG:
+            print(f"""Request info: {response.request_info}
 Response content: {content}""")
 
         try:
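
Both hunks above gate diagnostic printing behind `_GLOBAL_AIOHTTP_CLIENT_REQUEST_DEBUG` while still raising on a failed response. A minimal sketch of that pattern is shown below. `DEBUG_REQUESTS`, `FakeResponse`, `check_response`, and `run_and_capture` are hypothetical stand-ins introduced for illustration, not names from `nemo_gym`, and the real code operates on an aiohttp `ClientResponse` rather than this toy class.

```python
import io
from contextlib import redirect_stdout
from dataclasses import dataclass

# Hypothetical module-level switch standing in for the PR's debug flag.
DEBUG_REQUESTS = False


@dataclass
class FakeResponse:
    """Toy stand-in for an HTTP response (an assumption, not aiohttp)."""

    status: int
    body: str

    @property
    def ok(self) -> bool:
        # Treat any status below 400 as success.
        return self.status < 400


def check_response(response: FakeResponse, debug: bool = DEBUG_REQUESTS) -> None:
    """Raise on a failed response; print diagnostics only when debug is on."""
    if not response.ok:
        if debug:
            print(f"Response content: {response.body}")
        raise RuntimeError(f"request failed with status {response.status}")


def run_and_capture(response: FakeResponse, debug: bool) -> tuple:
    """Return (captured stdout, whether an error was raised)."""
    buffer = io.StringIO()
    raised = False
    with redirect_stdout(buffer):
        try:
            check_response(response, debug=debug)
        except RuntimeError:
            raised = True
    return buffer.getvalue(), raised
```

The point of the design is that the error path is unchanged: a failed response still raises with the flag off, and only the extra output is suppressed.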
43 changes: 43 additions & 0 deletions resources_servers/aviary/README.md
@@ -0,0 +1,43 @@
> Keywords: Tool Use, Multi-step Reasoning, Environment Interaction, Scientific Tasks

This resource server adapts [Aviary environments](https://github.com/Future-House/aviary) into the NeMo Gym resources-server interface, so NeMo Gym agents can interact with Aviary `Environment`s. This allows one to implement tool and environment logic in Aviary, and deploy the environment for inference or training with Gym.

### Implemented servers in this folder

- **GSM8K**: `gsm8k_app.py`
  - Meant primarily as an example, this implements [GSM8k](https://arxiv.org/abs/2110.14168) as a set of environments equipped with a calculator tool.
- **HotPotQA**: `hotpotqa_app.py`
  - The HotPotQA environment asks agents to perform multi-hop question answering on the [HotPotQA dataset](https://aclanthology.org/D18-1259/).
- **BixBench**: `notebook_app.py`
  - Implements the [BixBench dataset](https://arxiv.org/abs/2503.00096) as a set of environments that allow execution of a Jupyter notebook.
  - Also serves as an example of how to implement notebook-backed environments for other scientific computational tasks.
- **Client/proxy to a remote Aviary dataset server**: `client_app.py`
  - A generic interface to an Aviary `TaskDatasetServer`. Can be used to interact with any Aviary environments being served remotely.


# Example usage

Run the GSM8K Aviary resources server together with a model config:

```bash
config_paths="resources_servers/aviary/configs/gsm8k_aviary.yaml,\
responses_api_models/vllm_model/configs/vllm_model.yaml"
ng_run "+config_paths=[$config_paths]"
```

Then collect rollouts:

```bash
ng_collect_rollouts \
+agent_name=gsm8k_aviary_agent +input_jsonl_fpath=resources_servers/aviary/data/example.jsonl \
+output_jsonl_fpath=resources_servers/aviary/data/example_rollouts.jsonl
```

# Licensing information
Code: Apache 2.0

Data: MIT (GSM8k), Apache 2.0 (BixBench)

Dependencies
- nemo_gym: Apache 2.0
- aviary: Apache 2.0
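
The adaptation this README describes comes down to a per-episode lifecycle that `app.py` in this PR exposes as `seed_session`, `step`, `verify`, and `close`. The sketch below mimics that flow in memory, under stated assumptions: `CountdownEnv`, `SessionRegistry`, and `run_episode` are hypothetical names, the toy `step` returns three values where the PR unpacks four from Aviary's `env.step`, and no HTTP layer is involved.

```python
import asyncio
import uuid
from collections import defaultdict
from typing import Dict, List, Tuple


class CountdownEnv:
    """Hypothetical toy environment: finishes after `steps` steps, reward 1.0 at the end."""

    def __init__(self, steps: int) -> None:
        self.remaining = steps
        self.closed = False

    async def reset(self) -> Tuple[List[str], List[str]]:
        # Returns (initial observations, tool names).
        return [f"{self.remaining} steps remaining"], ["tick"]

    async def step(self, action: str) -> Tuple[List[str], float, bool]:
        self.remaining -= 1
        done = self.remaining <= 0
        return [f"{self.remaining} steps remaining"], (1.0 if done else 0.0), done

    async def close(self) -> None:
        self.closed = True


class SessionRegistry:
    """Keeps one environment and one running reward total per session id."""

    def __init__(self) -> None:
        self.envs: Dict[str, CountdownEnv] = {}
        self.total_reward: Dict[str, float] = defaultdict(float)

    async def seed_session(self, steps: int) -> Tuple[str, List[str], List[str]]:
        env_id = str(uuid.uuid4())
        env = CountdownEnv(steps)
        self.envs[env_id] = env
        obs, tools = await env.reset()
        return env_id, obs, tools

    async def step(self, env_id: str, action: str) -> Tuple[List[str], float, bool]:
        obs, reward, done = await self.envs[env_id].step(action)
        self.total_reward[env_id] += reward  # accumulate reward across steps
        return obs, reward, done

    def verify(self, env_id: str) -> float:
        return self.total_reward[env_id]

    async def close(self, env_id: str) -> bool:
        # Deregister and close; report failure instead of raising.
        try:
            await self.envs.pop(env_id).close()
        except Exception:
            return False
        return True


async def run_episode(steps: int = 3) -> Tuple[float, bool, bool]:
    registry = SessionRegistry()
    env_id, _obs, _tools = await registry.seed_session(steps)
    done = False
    try:
        while not done:
            _obs, _reward, done = await registry.step(env_id, "tick")
        total = registry.verify(env_id)
    finally:
        closed = await registry.close(env_id)
    # A second close finds no registered environment and reports failure.
    closed_again = await registry.close(env_id)
    return total, closed, closed_again
```

Calling `asyncio.run(run_episode(3))` walks one episode end to end. Closing in a `finally` block mirrors the intent of the commit "do try: agent rollout, finally: env close".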
Empty file.
164 changes: 164 additions & 0 deletions resources_servers/aviary/app.py
@@ -0,0 +1,164 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import uuid
from abc import ABC
from collections import defaultdict
from typing import Generic, TypeVar, cast

from fastapi import FastAPI, Request
from openai.types.responses import FunctionToolParam
from pydantic import ConfigDict, Field

from aviary.core import (
    Environment,
    EnvStateMessage,
    Message,
    TaskDataset,
    Tool,
    ToolCall,
    ToolCallFunction,
    ToolRequestMessage,
    ToolResponseMessage,
)
from nemo_gym.base_resources_server import SimpleResourcesServer
from nemo_gym.openai_utils import NeMoGymEasyInputMessage, NeMoGymFunctionCallOutput

from .schemas import (
    AviaryAgentVerifyRequest,
    AviaryAgentVerifyResponse,
    AviaryCloseRequest,
    AviaryCloseResponse,
    AviaryEnvStateEasyInputMessage,
    AviaryResourcesServerConfig,
    AviarySeedSessionRequest,
    AviarySeedSessionResponse,
    AviaryStepRequest,
    AviaryStepResponse,
)


logger = logging.getLogger(__name__)


TEnv = TypeVar("TEnv", bound=Environment)
TDataset = TypeVar("TDataset", bound=TaskDataset)


def tool_to_function_tool_param(tool: Tool) -> FunctionToolParam:
    tool_dump = tool.info.model_dump()
    tool_dump["parameters"].setdefault("additionalProperties", False)
    return FunctionToolParam(type="function", strict=True, **tool_dump)


def obs_msg_to_nemo_gym(obs: Message) -> list[NeMoGymEasyInputMessage]:
    # This does some Qwen3-specific things:
    # 1. if content is a JSON list, we flatten it to a list of messages. Qwen3's
    #    chat template doesn't support messages with list contents
    # 2. if content contains images (or really any other media), we drop it for now.
    # Most of this is what we'd call a HACK.

    is_env_state = isinstance(obs, EnvStateMessage) or (obs.info or {}).get("is_env_state", False)

    dump = obs.model_dump()
    try:
        content: str | list = json.loads(dump["content"])
    except json.JSONDecodeError:
        content = dump["content"]

    flat_content: list[str] = []
    if isinstance(content, list):
        flat_content = [c["text"] for c in content if c["type"] == "text"]
    else:
        flat_content = [content]

    message_cls = AviaryEnvStateEasyInputMessage if is_env_state else NeMoGymEasyInputMessage
    return [message_cls.model_validate(dump | {"content": c}) for c in flat_content]


class AviaryResourcesServer(SimpleResourcesServer, Generic[TEnv, TDataset], ABC):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    config: AviaryResourcesServerConfig
    dataset: TDataset
    env_id_to_env: dict[str, TEnv] = Field(default_factory=dict)
    env_id_to_total_reward: dict[str, float] = Field(default_factory=lambda: defaultdict(float))

    def setup_webserver(self) -> FastAPI:
        app = super().setup_webserver()
        app.post("/step")(self.step)
        app.post("/close")(self.close)
        return app

    async def seed_session(self, request: Request, body: AviarySeedSessionRequest) -> AviarySeedSessionResponse:
        """
        Wraps creation of the Aviary environment and calling reset().
        """
        env_id = str(uuid.uuid4())
        env = cast(Environment, self.dataset.get_new_env_by_idx(body.task_idx))
        self.env_id_to_env[env_id] = env

        obs, tools = await env.reset()
        return AviarySeedSessionResponse(
            env_id=env_id,
            obs=[message for o in obs for message in obs_msg_to_nemo_gym(o)],
            tools=[tool_to_function_tool_param(t) for t in tools],
        )

    async def step(self, request: Request, body: AviaryStepRequest) -> AviaryStepResponse:
        """
        Wraps calling step().
        """
        try:
            env = self.env_id_to_env[body.env_id]

            action = ToolRequestMessage(
                content=None,
                tool_calls=[
                    ToolCall(id=a.call_id, function=ToolCallFunction(name=a.name, arguments=json.loads(a.arguments)))
                    for a in body.action
                ],
            )
            obs, reward, done, _ = await env.step(action)

            self.env_id_to_total_reward[body.env_id] += reward

            nemo_obs = [
                message
                for o in obs
                for message in (
                    [NeMoGymFunctionCallOutput(call_id=o.tool_call_id, output=o.content)]
                    if isinstance(o, ToolResponseMessage)
                    else obs_msg_to_nemo_gym(o)
                )
            ]
        except Exception:
            logger.exception("Error in step")
            raise

        return AviaryStepResponse(obs=nemo_obs, reward=reward, done=done)

    async def verify(self, request: Request, body: AviaryAgentVerifyRequest) -> AviaryAgentVerifyResponse:
        return AviaryAgentVerifyResponse(**body.model_dump(), reward=self.env_id_to_total_reward[body.response.env_id])

    async def close(self, request: Request, body: AviaryCloseRequest) -> AviaryCloseResponse:
        """
        Closes and deregisters body.env_id.
        """
        try:
            await self.env_id_to_env.pop(body.env_id).close()
        except Exception as e:
            return AviaryCloseResponse(message=repr(e), success=False)
        return AviaryCloseResponse(message="Success", success=True)
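
The content-flattening step that `obs_msg_to_nemo_gym` performs in the file above can be illustrated in isolation. The following is a sketch on plain strings rather than Aviary `Message` objects, and `flatten_content` is a hypothetical helper, not part of the PR. It also makes two choices that the source does not confirm: content that parses as non-list JSON (for example `42`) is kept as the original string, and `None` yields an empty list.

```python
import json
from typing import List, Optional


def flatten_content(raw: Optional[str]) -> List[str]:
    """Split message content into plain-text pieces.

    - A JSON list of parts yields the text of each {"type": "text"} part;
      other parts (for example images) are dropped.
    - Anything else is returned unchanged as a single piece.
    """
    if raw is None:
        # Sketch-specific choice: no content means no messages.
        return []
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Ordinary prose is not valid JSON, so it passes through untouched.
        return [raw]
    if isinstance(parsed, list):
        return [
            part["text"]
            for part in parsed
            if isinstance(part, dict) and part.get("type") == "text"
        ]
    # Valid JSON that is not a list (e.g. "42") stays as the original string.
    return [raw]
```

Each returned piece would then become its own message, which is how the function above works around chat templates that reject list-valued content.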
64 changes: 64 additions & 0 deletions resources_servers/aviary/client_app.py
@@ -0,0 +1,64 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from pydantic import field_validator, model_validator

from aviary.core import TaskDatasetClient, TaskEnvironmentClient
from resources_servers.aviary.app import AviaryResourcesServer, AviaryResourcesServerConfig


class AviaryClientResourcesServerConfig(AviaryResourcesServerConfig):
    server_url: str | None = None
    request_timeout: float | None = 300.0
    api_key: str | None = None

    @field_validator("server_url")
    @classmethod
    def validate_server_url(cls, v: str | None) -> str | None:
        if v is None:
            return os.getenv("AVIARY_SERVER_URL")
        return v

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, v: str | None) -> str | None:
        if v is None:
            # Note that if AVIARY_SERVER_API_KEY is not set, we will return
            # None, which is also valid - assuming the server does
            # not have auth enabled.
            return os.getenv("AVIARY_SERVER_API_KEY")
        return v


class AviaryClientResourcesServer(AviaryResourcesServer[TaskEnvironmentClient, TaskDatasetClient]):
    config: AviaryClientResourcesServerConfig
    dataset: TaskDatasetClient

    @model_validator(mode="before")
    @classmethod
    def load_dataset(cls, data: dict) -> dict:
        if "dataset" not in data:
            config = data["config"] = AviaryClientResourcesServerConfig.model_validate(data.get("config", {}))
            data["dataset"] = TaskDatasetClient(
                server_url=config.server_url,
                request_timeout=config.request_timeout,
                api_key=config.api_key,
                catch_http_errors=True,
            )
        return data


if __name__ == "__main__":
    AviaryClientResourcesServer.run_webserver()
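
The environment-variable fallback in the config class above can be sketched with the standard library alone. `ClientSettings` and `resolve_settings` are hypothetical names used for illustration; only the variable names `AVIARY_SERVER_URL` and `AVIARY_SERVER_API_KEY` come from the PR. One caveat worth keeping in mind when reading the original: pydantic v2 skips field validators for default values unless `validate_default=True` is set, so whether the PR's fallback fires when a key is omitted entirely depends on configuration not shown in this diff.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class ClientSettings:
    """Hypothetical settings holder; not the PR's pydantic config class."""

    server_url: Optional[str] = None
    api_key: Optional[str] = None


def resolve_settings(
    settings: ClientSettings, env: Optional[Mapping[str, str]] = None
) -> ClientSettings:
    """Fill unset fields from the environment; explicit values always win."""
    env = os.environ if env is None else env
    return ClientSettings(
        server_url=(
            settings.server_url
            if settings.server_url is not None
            else env.get("AVIARY_SERVER_URL")
        ),
        # A missing key stays None, which is acceptable when the
        # remote server has no authentication enabled.
        api_key=(
            settings.api_key
            if settings.api_key is not None
            else env.get("AVIARY_SERVER_API_KEY")
        ),
    )
```

Passing `env` explicitly keeps the sketch testable without mutating the process environment.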
40 changes: 40 additions & 0 deletions resources_servers/aviary/configs/aviary.yaml
@@ -0,0 +1,40 @@
# NOTE: AviaryResourcesServer is an abstract base server and cannot be instantiated directly,
# so here we run the GSM8k server for tests/examples. Users should use the other configurations
# in this directory to instantiate their desired server.
gsm8k_aviary_resources_server:
  resources_servers:
    aviary:
      entrypoint: gsm8k_app.py
      domain: math
      verified: false
gsm8k_aviary_agent:
  responses_api_agents:
    aviary_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: gsm8k_aviary_resources_server
      model_server:
        type: responses_api_models
        name: policy_model
      max_steps: 10
      datasets:
      - name: train
        type: train
        jsonl_fpath: resources_servers/aviary/data/gsm8k_train.jsonl
        gitlab_identifier:
          dataset_name: PLACEHOLDER
          version: 0.0.1
          artifact_fpath: train.jsonl
        license: Apache 2.0
      - name: validation
        type: validation
        jsonl_fpath: resources_servers/aviary/data/gsm8k_validation.jsonl
        gitlab_identifier:
          dataset_name: PLACEHOLDER
          version: 0.0.1
          artifact_fpath: train.jsonl
        license: Apache 2.0
      - name: example
        type: example
        jsonl_fpath: resources_servers/aviary/data/gsm8k_example.jsonl
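
The agent entry in this config points at other server instances by `type` and `name`. The sketch below checks such references against a plain Python dict that mirrors the file. It assumes the nesting suggested by the key names and that a reference resolves when a top-level key with that name holds a block of the given type; `CONFIG` and `unresolved_references` are hypothetical, and this is not how NeMo Gym itself is confirmed to validate configs.

```python
from typing import Any, Dict, List

# Dict mirror of the YAML above (datasets omitted for brevity).
CONFIG: Dict[str, Any] = {
    "gsm8k_aviary_resources_server": {
        "resources_servers": {
            "aviary": {"entrypoint": "gsm8k_app.py", "domain": "math", "verified": False}
        }
    },
    "gsm8k_aviary_agent": {
        "responses_api_agents": {
            "aviary_agent": {
                "entrypoint": "app.py",
                "resources_server": {
                    "type": "resources_servers",
                    "name": "gsm8k_aviary_resources_server",
                },
                "model_server": {"type": "responses_api_models", "name": "policy_model"},
                "max_steps": 10,
            }
        }
    },
}


def unresolved_references(config: Dict[str, Any]) -> List[str]:
    """Return '<instance>.<field> -> <name>' for references that do not resolve."""
    problems = []
    for instance_name, by_type in config.items():
        for servers in by_type.values():
            for server_cfg in servers.values():
                for field, value in server_cfg.items():
                    # A reference is any mapping carrying both "type" and "name".
                    is_ref = isinstance(value, dict) and value.keys() >= {"type", "name"}
                    if not is_ref:
                        continue
                    target = config.get(value["name"], {})
                    if value["type"] not in target:
                        problems.append(f"{instance_name}.{field} -> {value['name']}")
    return problems
```

On this file alone the check reports `policy_model` as unresolved, which is consistent with the README's instruction to load a model config such as `vllm_model.yaml` alongside it.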
16 changes: 16 additions & 0 deletions resources_servers/aviary/configs/bixbench_aviary.yaml
@@ -0,0 +1,16 @@
bixbench_aviary_resources_server:
  resources_servers:
    aviary:
      entrypoint: notebook_app.py
      domain: coding
      verified: false
bixbench_aviary_agent:
  responses_api_agents:
    aviary_agent:
      entrypoint: app.py
      resources_server:
        type: resources_servers
        name: bixbench_aviary_resources_server
      model_server:
        type: responses_api_models
        name: policy_model