Merged
52 commits
9173c02
refactor to improved limit handling
bxyu-nvidia Sep 19, 2025
ac9dbe5
use httpx-aiohttp
bxyu-nvidia Sep 20, 2025
1039075
be loop aware
bxyu-nvidia Sep 20, 2025
81822c2
fixes
bxyu-nvidia Sep 20, 2025
933788d
lazy init
bxyu-nvidia Sep 20, 2025
31c3453
idk
bxyu-nvidia Sep 20, 2025
a6da219
add back origin based httpx client
bxyu-nvidia Sep 20, 2025
cf3efb9
skip tokenize call
bxyu-nvidia Sep 21, 2025
1846e70
quiet httpx logs
bxyu-nvidia Sep 21, 2025
e9d28ba
improve rollout collection efficiency
bxyu-nvidia Sep 21, 2025
a65d389
try make model response read more efficient
bxyu-nvidia Sep 21, 2025
567cc00
improve server utils
bxyu-nvidia Sep 21, 2025
86c5a34
set keepalive expiry very large
bxyu-nvidia Sep 21, 2025
31ba06a
retry in server client
bxyu-nvidia Sep 21, 2025
056be91
comment
bxyu-nvidia Sep 21, 2025
b24bdd4
revert read
bxyu-nvidia Sep 21, 2025
d49225a
comment
bxyu-nvidia Sep 21, 2025
618e7bd
clean headers
bxyu-nvidia Sep 21, 2025
4350f5c
use model validate
bxyu-nvidia Sep 21, 2025
6ac625e
try set ulimit
bxyu-nvidia Sep 21, 2025
be15207
idk tweak
bxyu-nvidia Sep 21, 2025
ff25bd3
remove this ulimit
bxyu-nvidia Sep 21, 2025
33f691f
start refactor to only use aiohttp
bxyu-nvidia Sep 21, 2025
ab03285
start refactor
bxyu-nvidia Sep 21, 2025
c0c09eb
fixes
bxyu-nvidia Sep 21, 2025
ef6c43f
fixes
bxyu-nvidia Sep 21, 2025
fb3a4c5
fix name
bxyu-nvidia Sep 21, 2025
a5a5aa0
test fixes
bxyu-nvidia Sep 21, 2025
2e11e45
fixes
bxyu-nvidia Sep 21, 2025
cc6dfe9
fixes
bxyu-nvidia Sep 21, 2025
d5139c4
fixes
bxyu-nvidia Sep 21, 2025
1db456e
runtime fixes
bxyu-nvidia Sep 21, 2025
0436512
add shutdown
bxyu-nvidia Sep 21, 2025
62943c3
clean print
bxyu-nvidia Sep 21, 2025
0231e01
fixes
bxyu-nvidia Sep 21, 2025
2fe6eae
fixes
bxyu-nvidia Sep 21, 2025
21e5123
refactor into requests function
bxyu-nvidia Sep 21, 2025
afc30b0
dont increment retries for server disconnected error
bxyu-nvidia Sep 21, 2025
9659255
don't poll http statuses during server lifetime
bxyu-nvidia Sep 21, 2025
f1ab7b8
clean
bxyu-nvidia Sep 21, 2025
c83c88f
swap back to old tokenization flow and add comments there
bxyu-nvidia Sep 21, 2025
f951165
dont pop token ids
bxyu-nvidia Sep 21, 2025
876c032
add set global aiohttp client fn
bxyu-nvidia Sep 21, 2025
837d0aa
cleanup rollout collection
bxyu-nvidia Sep 21, 2025
6e2d73e
close and reopen
bxyu-nvidia Sep 21, 2025
582ff21
refactor to accommodate rollout collection helper
bxyu-nvidia Sep 21, 2025
2c8c6f9
pass through head server config
bxyu-nvidia Sep 21, 2025
48cb4ea
filter out 200 ok messages to clean up server logs
bxyu-nvidia Sep 21, 2025
597ccee
add faq
bxyu-nvidia Sep 21, 2025
fb4a4fc
empty commit for qa
bxyu-nvidia Sep 21, 2025
a3af3bb
add aiohttp package
bxyu-nvidia Sep 21, 2025
c8f240a
improve readme
bxyu-nvidia Sep 22, 2025
41 changes: 41 additions & 0 deletions README.md
@@ -25,6 +25,7 @@
- [FAQ: build-docs / Build docs CI failures](#faq-build-docs--build-docs-ci-failures)
- [FAQ: NeMo Gym, training frameworks, and token IDs](#faq-nemo-gym-training-frameworks-and-token-ids)
- [FAQ: NeMo Gym what CI/CD do I need to pass?](#faq-nemo-gym-what-cicd-do-i-need-to-pass)
- [FAQ: Why aiohttp backend and not httpx/httpcore for async http?](#faq-why-aiohttp-backend-and-not-httpxhttpcore-for-async-http)


# NeMo-Gym
@@ -886,3 +887,43 @@ Examples of PR checks that most PRs do not need to wait for to pass:
1. CICD NeMo / cicd-container-build / build / main (push)
2. CICD NeMo / Nemo_CICD_Test (push)
...


# FAQ: Why aiohttp backend and not httpx/httpcore for async http?

TL;DR: httpx (via httpcore) has O(n^2) runtime where n is the number of queued requests: for each request, the connection pool scans all other queued requests. This is terribly inefficient and causes major slowdowns under high concurrency.

On Wed Sep 17, 2025, inspired by the DeepSeek-R1 Nature paper, we tried launching a larger rollout batch run with up to 16 off-policy steps in NeMo RL. Our setting resulted in Gym being slammed with 16k concurrent requests. At the time, we were using a single Gym instance with multiple data-parallel vLLM workers, and that setup hung for 40 minutes before the first request was processed. Something was wrong.

Before that, we had also received reports that Gym's rollout collection couldn't be used with high concurrency, e.g. in some cases users had to cap concurrency at 32 parallel requests. Putting these two data points together, we suspected something was wrong with Gym's concurrency setup.

For some context, Gym is a set of servers that ultimately call a model endpoint server. It is important that Gym never artificially restricts concurrency on its side: Gym is always a client of that model endpoint server, and the server may be able to handle far more requests than any client-side cap would allow. So we want Gym to be as efficient as possible and to avoid parameters like a max-parallel-requests limit.

Eventually, we isolated the issue to our async HTTP backend: httpx and httpcore. We originally chose httpx as Gym's async HTTP backend because the OpenAI client uses it by default, so we could share the same backing HTTP client. Unfortunately, the httpcore connection pool subroutine that assigns queued requests to connections is O(n^2) in the number of queued requests.

Networking mental model:
1. A request is sent by Gym to the model endpoint server.
2. This request requires a connection from our client side to the server side.
   1. This connection is a socket (identified by a port), and a socket is an open file (managed by the operating system).
   2. If we are sending 100 requests, in the worst case we could open 100 connections == 100 open files. This quickly becomes very expensive.
   3. So async HTTP backends pool requests across connections to a single endpoint: multiple requests can reuse the same open file if they target the same origin.
   4. This is called connection pooling, and it's possible for all 100 requests to share a single connection.
3. But connection pooling needs management logic: when the client sends a new request, the pool must decide whether that request can reuse an existing connection.
   1. This is where the httpcore connection pool logic is very inefficient.
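The asymptotics above can be modeled with a toy cost counter. This is a sketch of the worst-case behavior, not httpcore's actual code: assume each of n request events triggers a full scan over the remaining queue.

```python
def pool_assignment_steps(num_requests: int) -> int:
    """Toy model: count assignment-scan steps if every completed request
    triggers one full scan of the remaining queue (worst case)."""
    steps = 0
    queued = num_requests
    while queued > 0:
        steps += queued  # one scan over all still-queued requests
        queued -= 1
    return steps

# The cost is the triangular number n*(n+1)/2, i.e. O(n^2):
# 100 requests -> 5,050 steps; 16,000 requests -> over 128 million steps.
```

This is why a burst of 16k concurrent requests can stall for a long time even though each individual scan step is cheap.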

Here are the key calls in the stack trace:
1. The OpenAI client eventually calls into the httpx client.
2. The httpx client calls into the transport [here](https://github.com/encode/httpx/blob/4b23574cf83307ce27d3b14b4a425dc58c57d28d/httpx/_client.py#L1014)
3. The transport calls into the httpcore connection pool [here](https://github.com/encode/httpx/blob/4b23574cf83307ce27d3b14b4a425dc58c57d28d/httpx/_transports/default.py#L250)
4. For each request, the httpcore connection pool calls the `_assign_requests_to_connections` subroutine [here](https://github.com/encode/httpcore/blob/5974b03c7df89d3ee4e23779900d5349d550753c/httpcore/_async/connection_pool.py#L228)
   1. This subroutine loops through connections [here](https://github.com/encode/httpcore/blob/5974b03c7df89d3ee4e23779900d5349d550753c/httpcore/_async/connection_pool.py#L284)
   2. and loops through queued requests [here](https://github.com/encode/httpcore/blob/5974b03c7df89d3ee4e23779900d5349d550753c/httpcore/_async/connection_pool.py#L303)
   3. This results in O(n^2) total runtime when the number of queued requests is large, which is exactly what happens when we slam the server with a burst of requests.

In the end, we swapped our HTTP backend from httpx to aiohttp, since we had good prior experience with aiohttp in production infrastructure.
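A minimal sketch of the aiohttp pattern this enables: one lazily created, process-wide `ClientSession` whose connection pool is shared by all callers. The helper names and parameter values here are illustrative assumptions, not Gym's actual API.

```python
import asyncio
from typing import Optional

import aiohttp

_session: Optional[aiohttp.ClientSession] = None

async def get_global_session() -> aiohttp.ClientSession:
    """Lazily create, then reuse, a single shared ClientSession."""
    global _session
    if _session is None or _session.closed:
        connector = aiohttp.TCPConnector(
            limit=0,                 # no client-side cap on total connections
            keepalive_timeout=3600,  # keep idle connections open for a long time
        )
        _session = aiohttp.ClientSession(connector=connector)
    return _session

async def main() -> bool:
    session = await get_global_session()
    # Every caller shares the same session, and therefore the same pool.
    same = session is await get_global_session()
    await session.close()
    return same

shared = asyncio.run(main())
```

Setting `limit=0` removes aiohttp's client-side connection cap, matching the goal above of never artificially restricting concurrency on the Gym side.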

Here are some GitHub issues related to this problem. They didn't help much directly, but they did (somewhat) validate our choice of aiohttp as the async HTTP backend:
- https://github.com/openai/openai-python/issues/1596
- https://github.com/encode/httpx/issues/3215#issuecomment-2220795088

If you are using the AsyncOpenAI client with parallelism > 32, you may want to check whether this kind of inefficiency also affects your setup.
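One dependency-free way to check is to measure how long the first of n concurrent requests takes to complete. With an O(n^2) pool, time-to-first-response grows sharply with n; with a well-behaved pool it stays roughly flat. `fetch` below is a placeholder for your real request coroutine (e.g. one chat completion call).

```python
import asyncio
import time

async def time_to_first_response(fetch, n: int) -> float:
    """Launch n concurrent requests and time the first completion."""
    start = time.perf_counter()
    tasks = [asyncio.ensure_future(fetch()) for _ in range(n)]
    _, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    elapsed = time.perf_counter() - start
    for task in pending:  # don't leave stragglers running
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)
    return elapsed

async def _demo() -> float:
    async def fake_fetch() -> None:
        await asyncio.sleep(0.01)  # stand-in for a real network call
    return await time_to_first_response(fake_fetch, 100)

first_response_time = asyncio.run(_demo())
```

Run this at increasing n against your real endpoint; a superlinear trend in the result is the signature of the pool behavior described above.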
5 changes: 1 addition & 4 deletions nemo_gym/cli.py
@@ -223,6 +223,7 @@ def wait_for_spinup(self) -> None:
sleep(sleep_interval)

def shutdown(self) -> None:
# TODO there is possibly a better way to handle the server shutdowns.
for process_name, process in self._processes.items():
print(f"Killing `{process_name}`")
process.kill()
@@ -243,10 +244,6 @@ async def sleep():
# Indefinitely
while True:
self.poll()

statuses = self.check_http_server_statuses()
assert statuses.count("success") == len(statuses), "Found non-success statuses"

await asyncio.sleep(60)

try:
44 changes: 34 additions & 10 deletions nemo_gym/openai_utils.py
@@ -21,7 +21,6 @@
Union,
)

from openai import AsyncOpenAI
from openai.types.chat import (
ChatCompletion,
ChatCompletionAssistantMessageParam,
@@ -75,7 +74,7 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict

from nemo_gym.server_utils import get_global_httpx_client
from nemo_gym.server_utils import request


########################################
@@ -420,11 +419,36 @@ class NeMoGymChatCompletionCreateParamsNonStreaming(BaseModel):
########################################


class NeMoGymAsyncOpenAI(AsyncOpenAI):
def __init__(self, **kwargs) -> None:
# TODO: this setup is take from https://github.com/NVIDIA/NeMo-Skills/blob/80dc78ac758c4cac81c83a43a729e7ca1280857b/nemo_skills/inference/model/base.py#L318
# However, there may still be a lingering issue regarding saturating at 100 max connections
kwargs["http_client"] = get_global_httpx_client(kwargs["base_url"])
kwargs["timeout"] = None # Enforce no timeout

super().__init__(**kwargs)
class NeMoGymAsyncOpenAI(BaseModel):
"""This is just a stub class that wraps around aiohttp"""

base_url: str
api_key: str

async def create_chat_completion(self, **kwargs):
response = await request(
method="POST",
url=f"{self.base_url}/chat/completions",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
return await response.json()

async def create_response(self, **kwargs):
response = await request(
method="POST",
url=f"{self.base_url}/responses",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
return await response.json()

async def create_tokenize(self, **kwargs):
base_url = self.base_url.removesuffix("/v1")
response = await request(
method="POST",
url=f"{base_url}/tokenize",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
return await response.json()
96 changes: 63 additions & 33 deletions nemo_gym/rollout_collection.py
@@ -17,12 +17,19 @@
from collections import Counter
from contextlib import nullcontext
from itertools import chain, repeat
from typing import Optional
from typing import Dict, List, Optional

from pydantic import BaseModel
from tqdm.asyncio import tqdm

from nemo_gym.server_utils import ServerClient, get_global_config_dict
from nemo_gym.config_types import BaseServerConfig
from nemo_gym.server_utils import (
GlobalAIOHTTPAsyncClientConfig,
ServerClient,
get_global_config_dict,
is_global_aiohttp_client_setup,
set_global_aiohttp_client,
)


class RolloutCollectionConfig(BaseModel):
@@ -34,46 +41,69 @@ class RolloutCollectionConfig(BaseModel):
num_samples_in_parallel: Optional[int] = None


async def _collect_rollouts(config: RolloutCollectionConfig): # pragma: no cover
with open(config.input_jsonl_fpath) as input_dataset:
rows = list(map(json.loads, input_dataset))
print(f"Found {len(rows)} rows!")
class RolloutCollectionHelper(BaseModel): # pragma: no cover
async def run_from_config(self, config: RolloutCollectionConfig):
with open(config.input_jsonl_fpath) as input_dataset:
rows = list(map(json.loads, input_dataset))
print(f"Found {len(rows)} rows!")

if config.limit:
previous_length = len(rows)
rows = rows[: config.limit]
print(f"Limiting rows from {previous_length} to {len(rows)}!")
if config.limit:
previous_length = len(rows)
rows = rows[: config.limit]
print(f"Limiting rows from {previous_length} to {len(rows)}!")

if config.num_repeats:
previous_length = len(rows)
rows = list(chain.from_iterable(repeat(row, config.num_repeats) for row in rows))
print(f"Repeating rows (in a pattern of abc to aabbcc) from {previous_length} to {len(rows)}!")
if config.num_repeats:
previous_length = len(rows)
rows = list(chain.from_iterable(repeat(row, config.num_repeats) for row in rows))
print(f"Repeating rows (in a pattern of abc to aabbcc) from {previous_length} to {len(rows)}!")

server_client = ServerClient.load_from_global_config()
semaphore = nullcontext()
if config.num_samples_in_parallel:
semaphore = Semaphore(config.num_samples_in_parallel)

semaphore = nullcontext()
if config.num_samples_in_parallel:
semaphore = Semaphore(config.num_samples_in_parallel)
server_client = self.setup_server_client()

async def _post_coroutine(row: dict):
async with semaphore:
return await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
metrics = Counter()
with open(config.output_jsonl_fpath, "a") as f:

tasks = list(map(_post_coroutine, rows))
async def _post_coroutine(row: dict) -> None:
async with semaphore:
response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
result = await response.json()
f.write(json.dumps(result) + "\n")
metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))})

metrics = Counter()
pbar = tqdm.as_completed(tasks, desc="Collecting rollouts")
with open(config.output_jsonl_fpath, "a") as f:
for future in pbar:
result = await future
result = result.json()
f.write(json.dumps(result) + "\n")
metrics += Counter({k: v for k, v in result.items() if isinstance(v, (int, float))})
await tqdm.gather(*map(_post_coroutine, rows), desc="Collecting rollouts")

avg_metrics = {k: v / len(tasks) for k, v in metrics.items()}
print(json.dumps(avg_metrics, indent=4))
avg_metrics = {k: v / len(rows) for k, v in metrics.items()}

print(json.dumps(avg_metrics, indent=4))

async def run_examples(
self, examples: List[Dict], head_server_config: Optional[BaseServerConfig] = None
) -> List[Dict]:
server_client = self.setup_server_client(head_server_config)

async def _post_subroutine(row: Dict) -> Dict:
res = await server_client.post(server_name=row.pop("agent_ref")["name"], url_path="/run", json=row)
return await res.json()

return await tqdm.gather(*map(_post_subroutine, examples), desc="Collecting rollouts")

def setup_server_client(self, head_server_config: Optional[BaseServerConfig] = None) -> ServerClient:
server_client = ServerClient.load_from_global_config(head_server_config)

# We set this rollout global aiohttp client to use the same max connections as the underlying head server global config.
if not is_global_aiohttp_client_setup():
set_global_aiohttp_client(
cfg=GlobalAIOHTTPAsyncClientConfig.model_validate(server_client.global_config_dict)
)

return server_client


def collect_rollouts(): # pragma: no cover
config = RolloutCollectionConfig.model_validate(get_global_config_dict())
asyncio.run(_collect_rollouts(config))
rch = RolloutCollectionHelper()

asyncio.run(rch.run_from_config(config))