Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions nemo_gym/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from asyncio import sleep
from typing import (
Any,
Dict,
List,
Literal,
Expand Down Expand Up @@ -75,7 +77,7 @@
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict

from nemo_gym.server_utils import ClientResponse, raise_for_status, request
from nemo_gym.server_utils import MAX_NUM_TRIES, ClientResponse, raise_for_status, request


########################################
Expand Down Expand Up @@ -428,7 +430,7 @@ class NeMoGymAsyncOpenAI(BaseModel):

async def _request(self, **request_kwargs: Dict) -> ClientResponse:
tries = 0
while True:
while tries < MAX_NUM_TRIES:
tries += 1
response = await request(**request_kwargs)
# See https://platform.openai.com/docs/guides/error-codes/api-errors
Expand All @@ -442,33 +444,45 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse:
else:
return response

# We've exited the loop
response.raise_for_status()

async def _raise_for_status(self, response: ClientResponse, request_kwargs: Dict[str, Any]) -> None:
if not response.ok:
print(f"Request kwargs: {json.dumps(request_kwargs)}")

await raise_for_status(response)

async def create_chat_completion(self, **kwargs):
response = await self._request(
method="POST",
request_kwargs = dict(
url=f"{self.base_url}/chat/completions",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
await raise_for_status(response)
response = await self._request(method="POST", **request_kwargs)

await self._raise_for_status(response, request_kwargs)
return await response.json()

async def create_response(self, **kwargs):
response = await self._request(
method="POST",
request_kwargs = dict(
url=f"{self.base_url}/responses",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
await raise_for_status(response)
response = await self._request(method="POST", **request_kwargs)

await self._raise_for_status(response, request_kwargs)
return await response.json()

async def create_tokenize(self, **kwargs):
base_url = self.base_url.removesuffix("/v1")
response = await self._request(
method="POST",
request_kwargs = dict(
url=f"{base_url}/tokenize",
json=kwargs,
headers={"Authorization": f"Bearer {self.api_key}"},
)
await raise_for_status(response)
response = await self._request(method="POST", **request_kwargs)

await self._raise_for_status(response, request_kwargs)
return await response.json()
2 changes: 2 additions & 0 deletions nemo_gym/rollout_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ async def _post_coroutine(row: dict) -> None:
row["responses_create_params"] = row["responses_create_params"] | config.responses_create_params
async with semaphore:
response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row)
response.raise_for_status()
result = await response.json()
f.write(json.dumps(result) + "\n")
metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))})
Expand All @@ -96,6 +97,7 @@ async def run_examples(

async def _post_subroutine(row: Dict) -> Dict:
res = await server_client.post(server_name=row.pop("agent_ref")["name"], url_path="/run", json=row)
res.raise_for_status()
return await res.json()

return await tqdm.gather(*map(_post_subroutine, examples), desc="Collecting rollouts", miniters=10)
Expand Down
3 changes: 2 additions & 1 deletion nemo_gym/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ async def request(
async def raise_for_status(response: ClientResponse) -> None: # pragma: no cover
if not response.ok:
content = await response.content.read()
print(content)
print(f"""Request info: {response.request_info}
Response content: {content}""")
response.raise_for_status()


Expand Down
135 changes: 134 additions & 1 deletion responses_api_models/vllm_model/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,7 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch):
app = server.setup_webserver()
client = TestClient(app)

# START: First turn
mock_chat_completion = NeMoGymChatCompletion(
id="chtcmpl-123",
object="chat.completion",
Expand Down Expand Up @@ -1806,8 +1807,19 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch):
actual_messages = mock_method.call_args.kwargs["messages"]
assert expected_messages == actual_messages

# START: Second turn
input_messages = [
*input_messages,
*data["output"],
NeMoGymEasyInputMessage(
type="message",
role="user",
content=[NeMoGymResponseInputText(text="user", type="input_text")],
status="completed",
),
]
request_body = NeMoGymResponseCreateParamsNonStreaming(
input=input_messages + data["output"],
input=input_messages,
tools=input_tools,
)

Expand Down Expand Up @@ -1901,6 +1913,127 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch):
],
"reasoning_content": "Gathering order status and delivery info...",
},
{"content": [{"text": "user", "type": "text"}], "role": "user"},
]
actual_messages = mock_method.call_args.kwargs["messages"]
assert expected_messages == actual_messages

# START: Third turn
input_messages = [
*input_messages,
*data["output"],
NeMoGymEasyInputMessage(
type="message",
role="user",
content=[NeMoGymResponseInputText(text="user", type="input_text")],
status="completed",
),
]
request_body = NeMoGymResponseCreateParamsNonStreaming(
input=input_messages,
tools=input_tools,
)

mock_chat_completion = NeMoGymChatCompletion(
id="chtcmpl-123",
object="chat.completion",
created=FIXED_TIME,
model="dummy_model",
choices=[
NeMoGymChoice(
index=0,
finish_reason="tool_calls",
message=NeMoGymChatCompletionMessage(
role="assistant",
# Test the None path ehre
content="None reasoning test",
tool_calls=[],
reasoning_content=None,
),
)
],
)
mock_method = AsyncMock(return_value=mock_chat_completion.model_dump())
monkeypatch.setattr(
server._clients[0].__class__,
"create_chat_completion",
mock_method,
)

response = client.post(
"/v1/responses",
json=request_body.model_dump(exclude_unset=True, mode="json"),
)
assert response.status_code == 200

data = response.json()

expected_response = NeMoGymResponse(
**COMMON_RESPONSE_PARAMS,
id="resp_123",
object="response",
tools=input_tools,
created_at=FIXED_TIME,
model="dummy_model",
output=[
NeMoGymResponseOutputMessage(
id="msg_123",
status="completed",
type="message",
content=[
NeMoGymResponseOutputText(
type="output_text",
text="None reasoning test",
annotations=[],
logprobs=None,
)
],
),
],
)
expected_dict = expected_response.model_dump()
assert data == expected_dict

expected_messages = [
{"content": [{"text": "Check my order status", "type": "text"}], "role": "user"},
{
"role": "assistant",
"content": "Sure, one sec.",
"tool_calls": [],
"reasoning_content": "First reasoning item",
},
{"content": [{"text": "cool", "type": "text"}], "role": "user"},
{
"role": "assistant",
"content": "I'm still checking",
"tool_calls": [],
},
{"content": [{"text": "ok", "type": "text"}], "role": "user"},
{
"role": "assistant",
"content": " hello hello",
"tool_calls": [
{
"id": "call_123",
"function": {"arguments": '{"order_id": "123"}', "name": "get_order_status"},
"type": "function",
},
{
"id": "call_234",
"function": {"arguments": '{"order_id": "234"}', "name": "get_delivery_date"},
"type": "function",
},
],
"reasoning_content": "Gathering order status and delivery info...",
},
{"content": [{"text": "user", "type": "text"}], "role": "user"},
{
"role": "assistant",
"content": "",
"tool_calls": [],
"reasoning_content": "None content test",
},
{"content": [{"text": "user", "type": "text"}], "role": "user"},
]
actual_messages = mock_method.call_args.kwargs["messages"]
assert expected_messages == actual_messages
Expand Down