diff --git a/nemo_gym/openai_utils.py b/nemo_gym/openai_utils.py index 49a4b53ed..476b9be34 100644 --- a/nemo_gym/openai_utils.py +++ b/nemo_gym/openai_utils.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from asyncio import sleep from typing import ( + Any, Dict, List, Literal, @@ -75,7 +77,7 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import TypedDict -from nemo_gym.server_utils import ClientResponse, raise_for_status, request +from nemo_gym.server_utils import MAX_NUM_TRIES, ClientResponse, raise_for_status, request ######################################## @@ -428,7 +430,7 @@ class NeMoGymAsyncOpenAI(BaseModel): async def _request(self, **request_kwargs: Dict) -> ClientResponse: tries = 0 - while True: + while tries < MAX_NUM_TRIES: tries += 1 response = await request(**request_kwargs) # See https://platform.openai.com/docs/guides/error-codes/api-errors @@ -442,33 +444,45 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse: else: return response + # We've exited the loop + response.raise_for_status() + + async def _raise_for_status(self, response: ClientResponse, request_kwargs: Dict[str, Any]) -> None: + if not response.ok: + print(f"Request kwargs: {json.dumps(request_kwargs)}") + + await raise_for_status(response) + async def create_chat_completion(self, **kwargs): - response = await self._request( - method="POST", + request_kwargs = dict( url=f"{self.base_url}/chat/completions", json=kwargs, headers={"Authorization": f"Bearer {self.api_key}"}, ) - await raise_for_status(response) + response = await self._request(method="POST", **request_kwargs) + + await self._raise_for_status(response, request_kwargs) return await response.json() async def create_response(self, **kwargs): - response = await self._request( - method="POST", + request_kwargs = 
dict( url=f"{self.base_url}/responses", json=kwargs, headers={"Authorization": f"Bearer {self.api_key}"}, ) - await raise_for_status(response) + response = await self._request(method="POST", **request_kwargs) + + await self._raise_for_status(response, request_kwargs) return await response.json() async def create_tokenize(self, **kwargs): base_url = self.base_url.removesuffix("/v1") - response = await self._request( - method="POST", + request_kwargs = dict( url=f"{base_url}/tokenize", json=kwargs, headers={"Authorization": f"Bearer {self.api_key}"}, ) - await raise_for_status(response) + response = await self._request(method="POST", **request_kwargs) + + await self._raise_for_status(response, request_kwargs) return await response.json() diff --git a/nemo_gym/rollout_collection.py b/nemo_gym/rollout_collection.py index 0969a993a..bdf0f2ebf 100644 --- a/nemo_gym/rollout_collection.py +++ b/nemo_gym/rollout_collection.py @@ -79,6 +79,7 @@ async def _post_coroutine(row: dict) -> None: row["responses_create_params"] = row["responses_create_params"] | config.responses_create_params async with semaphore: response = await server_client.post(server_name=config.agent_name, url_path="/run", json=row) + response.raise_for_status() result = await response.json() f.write(json.dumps(result) + "\n") metrics.update({k: v for k, v in result.items() if isinstance(v, (int, float))}) @@ -96,6 +97,7 @@ async def run_examples( async def _post_subroutine(row: Dict) -> Dict: res = await server_client.post(server_name=row.pop("agent_ref")["name"], url_path="/run", json=row) + res.raise_for_status() return await res.json() return await tqdm.gather(*map(_post_subroutine, examples), desc="Collecting rollouts", miniters=10) diff --git a/nemo_gym/server_utils.py b/nemo_gym/server_utils.py index 7760bb32e..7f9e2ac97 100644 --- a/nemo_gym/server_utils.py +++ b/nemo_gym/server_utils.py @@ -150,7 +150,8 @@ async def request( async def raise_for_status(response: ClientResponse) -> None: # pragma: no 
cover if not response.ok: content = await response.content.read() - print(content) + print(f"""Request info: {response.request_info} +Response content: {content}""") response.raise_for_status() diff --git a/responses_api_models/vllm_model/tests/test_app.py b/responses_api_models/vllm_model/tests/test_app.py index cbb1b6926..283b20ab3 100644 --- a/responses_api_models/vllm_model/tests/test_app.py +++ b/responses_api_models/vllm_model/tests/test_app.py @@ -1595,6 +1595,7 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch): app = server.setup_webserver() client = TestClient(app) + # START: First turn mock_chat_completion = NeMoGymChatCompletion( id="chtcmpl-123", object="chat.completion", @@ -1806,8 +1807,19 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch): actual_messages = mock_method.call_args.kwargs["messages"] assert expected_messages == actual_messages + # START: Second turn + input_messages = [ + *input_messages, + *data["output"], + NeMoGymEasyInputMessage( + type="message", + role="user", + content=[NeMoGymResponseInputText(text="user", type="input_text")], + status="completed", + ), + ] request_body = NeMoGymResponseCreateParamsNonStreaming( - input=input_messages + data["output"], + input=input_messages, tools=input_tools, ) @@ -1901,6 +1913,127 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch): ], "reasoning_content": "Gathering order status and delivery info...", }, + {"content": [{"text": "user", "type": "text"}], "role": "user"}, + ] + actual_messages = mock_method.call_args.kwargs["messages"] + assert expected_messages == actual_messages + + # START: Third turn + input_messages = [ + *input_messages, + *data["output"], + NeMoGymEasyInputMessage( + type="message", + role="user", + content=[NeMoGymResponseInputText(text="user", type="input_text")], + status="completed", + ), + ] + request_body = NeMoGymResponseCreateParamsNonStreaming( + input=input_messages, + tools=input_tools, + ) + + 
mock_chat_completion = NeMoGymChatCompletion( + id="chtcmpl-123", + object="chat.completion", + created=FIXED_TIME, + model="dummy_model", + choices=[ + NeMoGymChoice( + index=0, + finish_reason="tool_calls", + message=NeMoGymChatCompletionMessage( + role="assistant", + # Test the None path here + content="None reasoning test", + tool_calls=[], + reasoning_content=None, + ), + ) + ], + ) + mock_method = AsyncMock(return_value=mock_chat_completion.model_dump()) + monkeypatch.setattr( + server._clients[0].__class__, + "create_chat_completion", + mock_method, + ) + + response = client.post( + "/v1/responses", + json=request_body.model_dump(exclude_unset=True, mode="json"), + ) + assert response.status_code == 200 + + data = response.json() + + expected_response = NeMoGymResponse( + **COMMON_RESPONSE_PARAMS, + id="resp_123", + object="response", + tools=input_tools, + created_at=FIXED_TIME, + model="dummy_model", + output=[ + NeMoGymResponseOutputMessage( + id="msg_123", + status="completed", + type="message", + content=[ + NeMoGymResponseOutputText( + type="output_text", + text="None reasoning test", + annotations=[], + logprobs=None, + ) + ], + ), + ], + ) + expected_dict = expected_response.model_dump() + assert data == expected_dict + + expected_messages = [ + {"content": [{"text": "Check my order status", "type": "text"}], "role": "user"}, + { + "role": "assistant", + "content": "Sure, one sec.", + "tool_calls": [], + "reasoning_content": "First reasoning item", + }, + {"content": [{"text": "cool", "type": "text"}], "role": "user"}, + { + "role": "assistant", + "content": "I'm still checking", + "tool_calls": [], + }, + {"content": [{"text": "ok", "type": "text"}], "role": "user"}, + { + "role": "assistant", + "content": " hello hello", + "tool_calls": [ + { + "id": "call_123", + "function": {"arguments": '{"order_id": "123"}', "name": "get_order_status"}, + "type": "function", + }, + { + "id": "call_234", + "function": {"arguments": '{"order_id": "234"}', "name": 
"get_delivery_date"}, + "type": "function", + }, + ], + "reasoning_content": "Gathering order status and delivery info...", + }, + {"content": [{"text": "user", "type": "text"}], "role": "user"}, + { + "role": "assistant", + "content": "", + "tool_calls": [], + "reasoning_content": "None content test", + }, + {"content": [{"text": "user", "type": "text"}], "role": "user"}, ] actual_messages = mock_method.call_args.kwargs["messages"] assert expected_messages == actual_messages