From 3b40c585722bdfe733f11ed4b08d6bcdd386de58 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Fri, 20 Mar 2026 22:51:35 +0100 Subject: [PATCH 01/11] fix(_get_tool_suffix_ids): keeping tokens after EoS for suffix --- trl/trainer/grpo_trainer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7354202258b..6665f2bbef6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1394,9 +1394,20 @@ def _get_tool_suffix_ids(self, tool_messages): return_dict=False, **self.chat_template_kwargs, ) - if not full_ids[: len(prefix_ids)] == prefix_ids: - raise ValueError("Unexpected tokenization: the prefix IDs are not a prefix of the full IDs.") - return full_ids[len(prefix_ids) :] + + # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\\n" after an assistant/tool block. + # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to + # EOS (not EOS + newline). + prefix_ids_for_suffix = prefix_ids + if self.eos_token_id is not None and self.eos_token_id in prefix_ids: + last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id) + if last_eos_idx < len(prefix_ids) - 1: + prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] + + if not full_ids[: len(prefix_ids_for_suffix)] == prefix_ids_for_suffix: + raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") + + return full_ids[len(prefix_ids_for_suffix) :] def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt From bc62c918d4d229a3e27b6ecabbb77196936fce1d Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Fri, 20 Mar 2026 22:52:36 +0100 Subject: [PATCH 02/11] fix(_build_messages_suffix_ids): keeping tokens after EoS for suffix (mirror _get_tool_suffix_ids) --- .../async_grpo/async_rollout_worker.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 480e9ef7818..bfb42d5dc4a 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -564,10 +564,20 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int tools=self.tools, **self.chat_template_kwargs, ) - prefix_len = len(prefix_ids) - if prefix_and_messages_ids[:prefix_len] != prefix_ids: - raise ValueError("Failed to construct message suffix in token space.") - return prefix_and_messages_ids[prefix_len:] + + # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\\n" after an assistant/tool block. + # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to + # EOS (not EOS + newline). + prefix_ids_for_suffix = prefix_ids + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id in prefix_ids: + last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) + if last_eos_idx < len(prefix_ids) - 1: + prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] + + if not prefix_and_messages_ids[: len(prefix_ids_for_suffix)] == prefix_ids_for_suffix: + raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") + + return prefix_and_messages_ids[len(prefix_ids_for_suffix) :] def _execute_tool_calls( self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable] From 9f8d3e9f6cd742a4a6e85cf73d102658f055199e Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Fri, 20 Mar 2026 22:58:56 +0100 Subject: [PATCH 03/11] added `test_get_tool_suffix_ids_eos_newline_boundary` --- tests/test_grpo_trainer.py | 62 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index bacb0a0258b..38e128cfdb2 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -38,6 +38,7 @@ from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer +from trl.chat_template_utils import get_training_chat_template from trl.import_utils import is_liger_kernel_available from trl.trainer.utils import get_kbit_device_map @@ -2641,6 +2642,67 @@ def test_single_reward_model_with_single_processing_class(self): assert len(trainer.reward_processing_classes) == 1 assert trainer.reward_processing_classes[0] == single_processing_class + @pytest.mark.parametrize( + "tokenizer_name", + [ + pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", + id="qwen35", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="Qwen3.5 tokenizer requires transformers>=5.0.0", + ), + ), + ], + ) + def test_get_tool_suffix_ids_eos_newline_boundary(self, tokenizer_name): + """Test _get_tool_suffix_ids when the prefix ends with EOS plus trailing tokens (e.g. Qwen3/Qwen3.5).""" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + chat_template = get_training_chat_template(tokenizer) + assert chat_template is not None + + # Call the method using a lightweight mock: _get_tool_suffix_ids only needs tokenizer + template fields. + mock_trainer = MagicMock() + mock_trainer.processing_class = tokenizer + mock_trainer.chat_template = chat_template + mock_trainer.chat_template_kwargs = {} + mock_trainer.eos_token_id = tokenizer.eos_token_id + + assert mock_trainer.eos_token_id is not None + + dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] + tool_messages = [{"role": "tool", "name": "dummy_tool", "content": '{"temperature": 18}'}] + + prefix_ids = tokenizer.apply_chat_template( + dummy_messages, + add_generation_prompt=False, + chat_template=mock_trainer.chat_template, + return_dict=False, + **mock_trainer.chat_template_kwargs, + ) + full_ids = tokenizer.apply_chat_template( + dummy_messages + tool_messages, + add_generation_prompt=True, + chat_template=mock_trainer.chat_template, + return_dict=False, + **mock_trainer.chat_template_kwargs, + ) + + assert mock_trainer.eos_token_id is not None + assert mock_trainer.eos_token_id in prefix_ids + last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == mock_trainer.eos_token_id) + prefix_ids_for_suffix = ( + prefix_ids + if last_eos_idx == len(prefix_ids) - 1 + else prefix_ids[: last_eos_idx + 1] # trim trailing "<|im_end|>\\n" -> "<|im_end|>" + ) + + expected_suffix_ids = full_ids[len(prefix_ids_for_suffix) :] + actual_suffix_ids = GRPOTrainer._get_tool_suffix_ids(mock_trainer, tool_messages) + + assert actual_suffix_ids == expected_suffix_ids + @pytest.mark.slow @require_torch_accelerator From e5b786bc05ac6767eb7c0364dbee6b48df42d964 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Fri, 20 Mar 2026 23:26:05 +0100 Subject: [PATCH 04/11] style --- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- trl/trainer/grpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index bfb42d5dc4a..a3a56461008 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -574,7 +574,7 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int if last_eos_idx < len(prefix_ids) - 1: prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] - if not prefix_and_messages_ids[: len(prefix_ids_for_suffix)] == prefix_ids_for_suffix: + if prefix_and_messages_ids[: len(prefix_ids_for_suffix)] != prefix_ids_for_suffix: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") return prefix_and_messages_ids[len(prefix_ids_for_suffix) :] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6665f2bbef6..de66160033e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1404,7 +1404,7 @@ def _get_tool_suffix_ids(self, tool_messages): if last_eos_idx < len(prefix_ids) - 1: prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] - if not full_ids[: len(prefix_ids_for_suffix)] == prefix_ids_for_suffix: + if full_ids[: len(prefix_ids_for_suffix)] != prefix_ids_for_suffix: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") return full_ids[len(prefix_ids_for_suffix) :] From a8e0f3ce8bf99ca63d91b7ed963526ead7c7b6fb Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:28:04 +0100 Subject: [PATCH 05/11] Update trl/trainer/grpo_trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index de66160033e..b12f0a42852 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1395,7 +1395,7 @@ def _get_tool_suffix_ids(self, tool_messages): **self.chat_template_kwargs, ) - # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\\n" after an assistant/tool block. + # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). prefix_ids_for_suffix = prefix_ids From 35964ab5cbcd3764d8c543bee5adb36548871f29 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:29:26 +0100 Subject: [PATCH 06/11] Update trl/trainer/grpo_trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/grpo_trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b12f0a42852..dd61d223795 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1398,16 +1398,13 @@ def _get_tool_suffix_ids(self, tool_messages): # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). - prefix_ids_for_suffix = prefix_ids - if self.eos_token_id is not None and self.eos_token_id in prefix_ids: - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id) - if last_eos_idx < len(prefix_ids) - 1: - prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] + last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id) + prefix_ids = prefix_ids[: last_eos_idx + 1] - if full_ids[: len(prefix_ids_for_suffix)] != prefix_ids_for_suffix: + if full_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return full_ids[len(prefix_ids_for_suffix) :] + return full_ids[len(prefix_ids) :] def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs, images, multimodal_fields): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt From 551509212dc998f1068f0e604a742d25fbbf6f66 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:33:44 +0100 Subject: [PATCH 07/11] reverted test --- tests/test_grpo_trainer.py | 61 -------------------------------------- 1 file changed, 61 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 38e128cfdb2..f809c3daf21 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2642,67 +2642,6 @@ def test_single_reward_model_with_single_processing_class(self): assert len(trainer.reward_processing_classes) == 1 assert trainer.reward_processing_classes[0] == single_processing_class - @pytest.mark.parametrize( - "tokenizer_name", - [ - pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - id="qwen35", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.0.0"), - reason="Qwen3.5 tokenizer requires transformers>=5.0.0", - ), - ), - ], - ) - def test_get_tool_suffix_ids_eos_newline_boundary(self, tokenizer_name): - """Test _get_tool_suffix_ids when the prefix ends with EOS plus trailing tokens (e.g. Qwen3/Qwen3.5).""" - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - chat_template = get_training_chat_template(tokenizer) - assert chat_template is not None - - # Call the method using a lightweight mock: _get_tool_suffix_ids only needs tokenizer + template fields. - mock_trainer = MagicMock() - mock_trainer.processing_class = tokenizer - mock_trainer.chat_template = chat_template - mock_trainer.chat_template_kwargs = {} - mock_trainer.eos_token_id = tokenizer.eos_token_id - - assert mock_trainer.eos_token_id is not None - - dummy_messages = [{"role": "user", "content": "dummy"}, {"role": "assistant", "content": "dummy"}] - tool_messages = [{"role": "tool", "name": "dummy_tool", "content": '{"temperature": 18}'}] - - prefix_ids = tokenizer.apply_chat_template( - dummy_messages, - add_generation_prompt=False, - chat_template=mock_trainer.chat_template, - return_dict=False, - **mock_trainer.chat_template_kwargs, - ) - full_ids = tokenizer.apply_chat_template( - dummy_messages + tool_messages, - add_generation_prompt=True, - chat_template=mock_trainer.chat_template, - return_dict=False, - **mock_trainer.chat_template_kwargs, - ) - - assert mock_trainer.eos_token_id is not None - assert mock_trainer.eos_token_id in prefix_ids - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == mock_trainer.eos_token_id) - prefix_ids_for_suffix = ( - prefix_ids - if last_eos_idx == len(prefix_ids) - 1 - else prefix_ids[: last_eos_idx + 1] # trim trailing "<|im_end|>\\n" -> "<|im_end|>" - ) - - expected_suffix_ids = full_ids[len(prefix_ids_for_suffix) :] - actual_suffix_ids = GRPOTrainer._get_tool_suffix_ids(mock_trainer, tool_messages) - - assert actual_suffix_ids == expected_suffix_ids - @pytest.mark.slow @require_torch_accelerator From a54b08e4721ef613975b8c73ec459b31b280b491 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:41:41 +0100 Subject: [PATCH 08/11] mirror simplification to async grpo --- trl/experimental/async_grpo/async_rollout_worker.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 777f181d093..c94bc1a5ad6 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -577,16 +577,13 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). - prefix_ids_for_suffix = prefix_ids - if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id in prefix_ids: - last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) - if last_eos_idx < len(prefix_ids) - 1: - prefix_ids_for_suffix = prefix_ids[: last_eos_idx + 1] + last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) + prefix_ids = prefix_ids[: last_eos_idx + 1] - if prefix_and_messages_ids[: len(prefix_ids_for_suffix)] != prefix_ids_for_suffix: + if prefix_and_messages_ids[: len(prefix_ids)] != prefix_ids: raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") - return prefix_and_messages_ids[len(prefix_ids_for_suffix) :] + return prefix_and_messages_ids[len(prefix_ids) :] def _execute_tool_calls( self, tool_calls: list[dict[str, Any]], tool_dict: dict[str, Callable] From b202f6dbef28d2995305bf09c6a829c7c8f213ee Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:49:34 +0100 Subject: [PATCH 09/11] removed unused import --- tests/test_grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index f809c3daf21..bacb0a0258b 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -38,7 +38,6 @@ from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer -from trl.chat_template_utils import get_training_chat_template from trl.import_utils import is_liger_kernel_available from trl.trainer.utils import get_kbit_device_map From b4adbb273bb6957c9068b8d52ce28c1fba197d3a Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 00:51:33 +0100 Subject: [PATCH 10/11] typo --- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index c94bc1a5ad6..fa2ab74447c 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -574,7 +574,7 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int **self.chat_template_kwargs, ) - # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\\n" after an assistant/tool block. + # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) From f973ad1181be675fb972939d36041e6854e37ec5 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 01:01:22 +0100 Subject: [PATCH 11/11] typo: aligned comment --- trl/experimental/async_grpo/async_rollout_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index fa2ab74447c..ac814b40ed7 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -575,7 +575,7 @@ def _build_messages_suffix_ids(self, messages: list[dict[str, Any]]) -> list[int ) # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. - # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to + # When we compute `suffix_ids` by slicing `prefix_and_messages_ids`, we must align the slicing boundary to # EOS (not EOS + newline). last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) prefix_ids = prefix_ids[: last_eos_idx + 1]