Update handling EOS token id detection#1925
Conversation
|
In search_cuda.cpp line 173, we are checking for eos do we need to set hit_eos there or no? onnxruntime-genai/src/cuda/search_cuda.cpp Line 173 in 4a13b80 |
|
Should we go ahead and close this PR now that the issue can be addressed through the existing API? |
There was a problem hiding this comment.
Pull request overview
This PR refines EOS/max-length handling in generation loops by relying on Generator.GetNextTokens() and a new TokenCount API, and surfaces search-parameter query APIs across languages so callers can distinguish termination conditions. It also adds tokenizer BOS/EOS/PAD accessors for Objective-C and refreshes examples and docs to reflect the new recommended generation patterns and model support.
Changes:
- Introduces
Generator.TokenCountandGeneratorParams.GetSearchNumber/GetSearchBoolat the core/C-API level and threads them through Python, C#, Java, Objective-C, with corresponding tests. - Standardizes generation loops in tests and examples to use
while !IsDone()plusGetNextTokens()for streaming, and updates platform tests to assert consistency between token count and sequence lengths. - Adds missing tokenizer APIs (BOS/EOS/PAD token ids) for Objective-C and updates the README support matrix and example usage/docs, while removing the deprecated graph-capture parameter API from public headers.
Reviewed changes
Copilot reviewed 47 out of 48 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| test/python/test_onnxruntime_genai_e2e.py | Simplifies the Python e2e test generation loop to while not generator.is_done() without an inner break, reflecting the new loop guidance. |
| test/python/test_onnxruntime_genai_api.py | Adds assertions around get_search_options() and generator.token_count() and converts multiple generation loops to while not generator.is_done(). |
| test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm | Uses C++ GetSearchNumber/GetSearchBool and TokenCount in the macOS UI test and updates the loop to while (!generator->IsDone()), plus post-generation TokenCount/sequence-length consistency checks. |
| test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm | Mirrors the macOS C++ API UI test changes for iOS, validating search params and TokenCount before and after generation. |
| test/model_tests.cpp | Replaces numerous while (true) + if (IsDone()) break; loops with while (!generator->IsDone()) in C++ model tests. |
| test/csharp/TestOnnxRuntimeGenAIAPI.cs | Updates C# tests to use GeneratorParams.GetSearchNumber/GetSearchBool and Generator.TokenCount(), and simplifies generation loops to while (!generator.IsDone()). |
| test/c_api_tests.cpp | Converts C API tests to while (!generator->IsDone()) and adds TokenCount assertions in the EOS/PAD test, relying on the new C API getter. |
| src/search.h | Extends the Search interface with a virtual ResetDone() method and declares overrides in CPU/CUDA search implementations. |
| src/search.cpp | Implements Search_Cpu::ResetDone and GreedySearch_Cpu::ResetDone, and reuses ResetDone() from AppendTokens and RewindTo to centralize done/eos_seen reset logic. |
| src/python/python.cpp | Removes the deprecated try_graph_capture_with_max_batch_size, adds PyGeneratorParams.get_search_options(), and binds generator.token_count() in the Python wrapper. |
| src/ort_genai_c.h | Removes the deprecated OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize, documents and declares C APIs for getting/setting search numbers/bools, OgaGenerator_TokenCount, runtime options, and tokenizer BOS/EOS/PAD accessors. |
| src/ort_genai_c.cpp | Implements the new C APIs for getting search numbers/bools and OgaGenerator_TokenCount, and wires tokenizer BOS/EOS/PAD getters to the underlying C++ APIs. |
| src/ort_genai.h | Adds C++ RAII wrappers GeneratorParams::GetSearchNumber/GetSearchBool, OgaGenerator::TokenCount, and a non-span GetNextTokens() variant while removing the deprecated graph-capture helper. |
| src/objectivec/oga_tokenizer.mm | Implements Objective-C methods to read BOS/EOS/PAD token ids via the C++ tokenizer and return them as scalar/NSArray values. |
| src/objectivec/oga_generator_params.mm | Adds Objective-C getSearchNumber: and getSearchBool: that call into OgaGeneratorParams::GetSearchNumber/GetSearchBool. |
| src/objectivec/oga_generator.mm | Introduces Objective-C tokenCount: that forwards to C++ OgaGenerator::TokenCount. |
| src/objectivec/include/ort_genai_objc.h | Updates the public Objective-C header to expose BOS/EOS/PAD tokenizer accessors, generator-param search getters, and generator tokenCount: with documentation. |
| src/objectivec/error_utils.h | Adds new Objective-C helper macros for catching C++ exceptions when returning double and int, but currently defines them with malformed macro signatures (compile-time bug). |
| src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java | Clarifies a Java test comment to indicate it validates setting a valid search option. |
| src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java | Uses getSearchNumber, getSearchBool, and generator.tokenCount() in the Java generation test alongside the updated while (!generator.isDone()) loop. |
| src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp | Implements JNI bridges for GeneratorParams.getSearchNumber and getSearchBool using the new C API getters. |
| src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp | Adds a JNI method tokenCount that forwards to OgaGenerator_TokenCount. |
| src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java | Exposes getSearchNumber and getSearchBool in the Java API and adds the corresponding native method declarations. |
| src/java/src/main/java/ai/onnxruntime/genai/Generator.java | Exposes tokenCount() in the Java Generator API and wires it to the new JNI-native method. |
| src/generators.h | Adds GeneratorParams::GetSearchNumber/GetSearchBool and Generator::TokenCount() to the core C++ generator interfaces. |
| src/generators.cpp | Implements search-option querying with explicit name dispatch and error on invalid names, and Generator::TokenCount() via search_->GetSequenceLength(). |
| src/cuda/search_cuda.h | Declares Search_Cuda::ResetDone() to match the extended Search interface. |
| src/cuda/search_cuda.cpp | Centralizes CUDA done/eos_seen reset logic in Search_Cuda::ResetDone() and reuses it from constructors, AppendTokens, and RewindTo. |
| src/csharp/NativeMethods.cs | Adds P/Invoke signatures for OgaGeneratorParamsGetSearchNumber/GetSearchBool and OgaGenerator_TokenCount. |
| src/csharp/GeneratorParams.cs | Wraps the new C# interop getters as GetSearchNumber/GetSearchBool on GeneratorParams and removes the deprecated graph-capture stub. |
| src/csharp/Generator.cs | Exposes TokenCount() on the C# Generator class, backed by the new native function. |
| src/config.h | Fixes a minor spacing typo in the early_stopping comment in Config::Search. |
| examples/python/phi4-mm.py | Updates the multimodal Python example’s generation loop to while not generator.is_done() with get_next_tokens() for streaming. |
| examples/python/phi3-qa.py | Aligns the QA Python example with the new loop pattern and keeps the timing logic around generate_next_token. |
| examples/python/model-vision.py | Uses while not generator.is_done() + get_next_tokens() in the vision example instead of while True with an inner is_done() break. |
| examples/python/model-qa.py | Similarly updates the text QA example loop to rely on is_done() only in the loop condition. |
| examples/python/model-generate.py | Simplifies the batch generation loop to while not generator.is_done() without an inner if is_done(): break. |
| examples/python/model-chat.py | Updates the chat example to use the new loop pattern while retaining timing and streaming decode via get_next_tokens(). |
| examples/python/awq-quantized-model.py | Applies the same loop update to the AWQ-quantized model example. |
| examples/csharp/HelloPhi4MM/Program.cs | Updates the C# multimodal example to while (!generator.IsDone()) with streaming decode via GetNextTokens()[0]. |
| examples/csharp/HelloPhi3V/Program.cs | Mirrors the multimodal loop update for the Phi3V C# example. |
| examples/csharp/HelloPhi/Program.cs | Converts several Phi C# examples (batch and streaming/chat) to the new while (!generator.IsDone()) pattern. |
| examples/c/src/whisper.cpp | Adjusts the Whisper C++ example’s generation loop to use while (!generator->IsDone()). |
| examples/c/src/phi4-mm.cpp | Switches the multimodal C++ example to while (!generator->IsDone()) and uses GetNextTokens()[0] instead of manually indexing GetSequenceData. |
| examples/c/src/model_vision.cpp | Same as phi4-mm: uses while (!generator->IsDone()) and GetNextTokens()[0] for last token retrieval. |
| examples/c/src/model_qa.cpp | Updates the QA C++ example loop and uses GetNextTokens()[0] for streaming outputs. |
| examples/c/src/model_chat.cpp | Updates the chat C++ example to use while (!generator->IsDone()) and GetNextTokens()[0] for token streaming. |
| README.md | Refreshes the support matrix (model names/lines), updates the recommended Python generation loop, revises version checkout instructions, and replaces the old “Breaking API changes” section with nightly install guidance. |
Version 2
Description
This PR updates the examples to show how EOS token id detection is handled with ONNX Runtime GenAI when generating tokens. With the addition of the C# binding for GetNextTokens(), all of the published examples now cover the cases listed below in version 1 of this PR. Previously, the earlier PR mentioned different variations of the generation loop and all of the variations had an issue.
This PR also introduces new APIs for tracking token count and querying the generator params:
Generator.TokenCountParams.GetSearchNumberParams.GetSearchBoolAdditionally, this PR adds some missing Tokenizer APIs for Objective-C.
Tokenizer.GetBosTokenIdTokenizer.GetEosTokenIdsTokenizer.GetPadTokenIdMotivation and Context
This PR is a follow-up to the issue fixed in an earlier PR. These APIs can be used by users to distinguish between the cases that
Generator.IsDone()covers.For example:
Version 1
Description
This PR updates how EOS token id detection is handled with ONNX Runtime GenAI when generating tokens. A new API called
Generator.HitEOS()is introduced to detect whether an EOS token id has been generated. Another API calledGenerator.HitMaxLength()is also introduced to detect whether the max length has been hit before the generation loop has completed.Motivation and Context
This PR is a follow-up to the issue fixed in an earlier PR. The earlier PR mentions different variations of the generation loop but all of the variations have an issue.
There are two scenarios for terminating the generation loop: 1) hitting the EOS token id and completing the generation loop or 2) hitting the max length before the generation loop has completed. However, none of the variations adequately cover the two scenarios for terminating the generation loop.
1. Original Generation Loop
Consider scenario 1 with this loop. After
GenerateToken()produces the EOS token id,GetLastToken()will attempt to retrieve that token. However, ORT GenAI does not append the EOS token id to the list of sequences returned to the user (see the earlier PR for why). Instead, the second-to-last token will still be the last token in the list of sequences. Thus,GetLastToken()andPrintLastToken()will retrieve and again print the last token that the user saw.2. Return Early Generation Loop
Consider scenario 2 with this loop. After
GenerateToken()produces a token and the max length has been reached, the generator's state is marked as done. ThenIsDone()will be true and the newest token won't be retrieved and printed since the loop is exited early.3. Infinite Generation Loop
Consider scenario 2 with this loop. The same issue as the prior loop still applies.
GenerateToken()will generate all of the tokens but once the max length is hit,IsDone()is true and the last token won't be retrieved and printed.Conclusion
The reason that none of these generation loop variants work is because
IsDone()currently covers both scenarios in one API and does not distinguish between them. One check needs to be in place in the condition of the while loop so that the loop continues, and another check needs to be after token generation to decide whether retrieving the last token should be done or not.Solution
To fix this, a new API called
Generator.HitEOS()is introduced. It returnstruewhen the EOS token id is generated. The generation loop should be modified to the following.If scenario 1 occurs in this loop,
HitEOS()istrueand the generation loop will exit early. If scenario 2 occurs in this loop,HitEOS()isfalsewhen the max length is reached. The last generated token can still be retrieved and printed. Then because the generator's state is done,IsDone()istrueand the generation loop ends.Here is a full end-to-end example demonstrating its usage.
Scenario 1
Before with loop version 1:
After with
generator.hit_eos():Scenario 2
Before with loop version 2:
After with
generator.hit_eos():