From 7df4aa75a330428acafe52f6d1f686dbe6265a18 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 17 Dec 2025 19:52:08 -0800 Subject: [PATCH 01/38] Add IsMaxLength API --- src/cuda/search_cuda.cpp | 24 +++++++++++++++--------- src/cuda/search_cuda.h | 7 +++++++ src/generators.cpp | 4 ++++ src/generators.h | 1 + src/search.cpp | 23 ++++++++++++++++------- src/search.h | 6 ++++++ 6 files changed, 49 insertions(+), 16 deletions(-) diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index ad5a7c6c5a..34d74a36ee 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -25,14 +25,13 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) sequence_lengths_ = params.p_device->Allocate(batch_beam_size); eos_seen_buffer_ = CudaMallocArray(batch_beam_size, &eos_seen_); - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); + done_cpu_ = CudaMallocHostArray(1); eos_token_ids_ = params.p_device->Allocate(params.config.model.eos_token_id.size()); copy(std::span{params.config.model.eos_token_id}, eos_token_ids_.CpuSpan()); eos_token_ids_.CopyCpuToDevice(); - done_cpu_ = CudaMallocHostArray(1); - *done_cpu_ = false; + ResetDone(); } GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) @@ -75,6 +74,13 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) BeamSearch_Cuda::~BeamSearch_Cuda() = default; +void Search_Cuda::ResetDone() { + *done_cpu_ = false; + *hit_eos_cpu_ = false; + *hit_max_length_cpu_ = false; + cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); +} + DeviceSpan Search_Cuda::GetLogits() const { return next_token_scores_; } @@ -176,6 +182,7 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "greedy cuda hit"); *done_cpu_ = true; + *hit_max_length_cpu_ = true; } } @@ -186,6 +193,7 @@ bool BeamSearch_Cuda::IsDone() const { if 
(sequences_.GetSequenceLength() == params_->search.max_length) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "beam cuda hit"); + *hit_max_length_cpu_ = true; return true; } return false; @@ -221,8 +229,7 @@ std::span Search_Cuda::GetScores() { // Set user input tokens (batch_beam_size, sequence_length) void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); auto next_tokens_gpu = next_tokens.Span(); cuda::Launch_AppendNextTokensToSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream()); @@ -232,11 +239,11 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "greedy cuda hit"); *done_cpu_ = true; + *hit_max_length_cpu_ = true; return; } - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); } void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { @@ -248,8 +255,7 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { } void GreedySearch_Cuda::RewindTo(size_t index) { - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); if (index > 0) cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast(params_->BatchBeamSize()), static_cast(index), sequences_.max_length_, GetStream()); else diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index 6d0f03f73d..6ffc34c15b 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -19,6 +19,11 @@ struct Search_Cuda : Search { cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event + bool IsMaxLength() const { + cudaStreamSynchronize(GetStream()); + return 
*hit_max_length_cpu_; + } + void ResetDone() override; DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; @@ -39,6 +44,8 @@ struct Search_Cuda : Search { DeviceSpan next_token_scores_; // shape (beam_size*batch_size, vocab_size) cuda_host_unique_ptr done_cpu_; + cuda_host_unique_ptr hit_eos_cpu_; + cuda_host_unique_ptr hit_max_length_cpu_; }; struct GreedySearch_Cuda : Search_Cuda { diff --git a/src/generators.cpp b/src/generators.cpp index d19751217c..14504d658a 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -472,6 +472,10 @@ bool Generator::IsDone() { return is_done; } +bool Generator::IsMaxLength() { + return search_->IsMaxLength(); +} + bool Generator::IsSessionTerminated() const { return state_->session_terminated_; } diff --git a/src/generators.h b/src/generators.h index 1d643e198b..6ae02f6cd1 100644 --- a/src/generators.h +++ b/src/generators.h @@ -95,6 +95,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); + bool IsMaxLength(); void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/search.cpp b/src/search.cpp index 1d69c30115..92af7b96e7 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -49,6 +49,15 @@ BeamSearch_Cpu::BeamSearch_Cpu(const GeneratorParams& params) BeamSearch_Cpu::~BeamSearch_Cpu() = default; +void Search_Cpu::ResetDone() { + // Reset done count/state + done_ = false; + hit_eos_ = false; + hit_max_length_ = false; + not_done_count_ = params_->search.batch_size; + memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); +} + DeviceSpan Search_Cpu::GetLogits() const { return next_token_scores_; } @@ -298,6 +307,7 @@ void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) { Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id)); if (--not_done_count_ == 0) { done_ = true; + hit_eos_ = true; } } } @@ -319,6 +329,7 @@ void 
GreedySearch_Cpu::AppendNextTokensToSequences() { if (g_log.enabled && g_log.hit_max_length) Log("hit_max_length", "greedy cpu hit"); done_ = true; + hit_max_length_ = true; } } @@ -333,16 +344,13 @@ void GreedySearch_Cpu::AppendTokens(DeviceSpan& next_tokens) { } AppendNextTokensToSequences(); } - // Reset done count/state - done_ = false; - not_done_count_ = params_->search.batch_size; - memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + + ResetDone(); } void GreedySearch_Cpu::RewindTo(size_t index) { - done_ = false; - not_done_count_ = params_->search.batch_size; - memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + ResetDone(); + // Set next tokens to the last tokens in the sequence if (index > 0) { for (int i = 0; i < params_->BatchBeamSize(); i++) { @@ -407,6 +415,7 @@ void BeamSearch_Cpu::AppendNextTokensToSequences() { if (g_log.enabled && g_log.hit_max_length) Log("hit_max_length", "beam cpu hit"); done_ = true; + hit_max_length_ = true; } } diff --git a/src/search.h b/src/search.h index 9d456368c1..4c47d66a0a 100644 --- a/src/search.h +++ b/src/search.h @@ -19,6 +19,8 @@ struct Search : LeakChecked { virtual DeviceSpan GetLogits() const = 0; virtual void SetLogits(DeviceSpan logits) = 0; virtual bool IsDone() const = 0; + virtual bool IsMaxLength() const = 0; + virtual void ResetDone() = 0; virtual void SelectTop() = 0; virtual void SampleTopP(float /*p*/, float /*temperature*/) { assert(false); } @@ -44,6 +46,8 @@ struct Search_Cpu : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const override { return done_; } + bool IsMaxLength() const override { return hit_max_length_; } + void ResetDone() override; DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; @@ -61,6 +65,8 @@ struct Search_Cpu : Search { DeviceSpan next_token_scores_; // shape (beam_size*batch_size, vocab_size) bool done_{}; + bool hit_eos_{}; + bool hit_max_length_{}; }; struct GreedySearch_Cpu : 
Search_Cpu { From 9d41ed96dd39e01fd09dcc0f25a8c2bef64c1d50 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 18 Dec 2025 04:50:49 +0000 Subject: [PATCH 02/38] Add HitEOS and HitMaxLength APIs --- src/csharp/Generator.cs | 10 +++++++ src/csharp/NativeMethods.cs | 8 ++++++ src/cuda/search_cuda.h | 6 ++++- src/generators.cpp | 8 ++++-- src/generators.h | 3 ++- .../java/ai/onnxruntime/genai/Generator.java | 26 +++++++++++++++++++ .../native/ai_onnxruntime_genai_Generator.cpp | 10 +++++++ src/objectivec/include/ort_genai_objc.h | 14 ++++++++++ src/objectivec/oga_generator.mm | 14 ++++++++++ src/ort_genai.h | 8 ++++++ src/ort_genai_c.cpp | 8 ++++++ src/ort_genai_c.h | 15 +++++++++++ src/python/python.cpp | 8 ++++++ src/search.h | 6 +++-- 14 files changed, 138 insertions(+), 6 deletions(-) diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index 0c7eb31d81..68282ba5e5 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -20,6 +20,16 @@ public bool IsDone() return NativeMethods.OgaGenerator_IsDone(_generatorHandle) != 0; } + public bool HitEOS() + { + return NativeMethods.OgaGenerator_HitEOS(_generatorHandle) != 0; + } + + public bool HitMaxLength() + { + return NativeMethods.OgaGenerator_HitMaxLength(_generatorHandle) != 0; + } + public void SetModelInput(string name, Tensor value) { Result.VerifySuccess(NativeMethods.OgaGenerator_SetModelInput(_generatorHandle, StringUtils.ToUtf8(name), value.Handle)); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index ce505b6ec1..ca9c5fb261 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -125,6 +125,14 @@ internal class NativeLib [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern byte OgaGenerator_IsDone(IntPtr /* const OgaGenerator* */ generator); + // This function is used to check if the generator has hit the EOS token id after generating all sequences. 
+ [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern byte OgaGenerator_HitEOS(IntPtr /* const OgaGenerator* */ generator); + + // This function is used to check if the generator has hit the max length after generating all sequences. + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern byte OgaGenerator_HitMaxLength(IntPtr /* const OgaGenerator* */ generator); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGenerator_GetNextTokens(IntPtr /* const OgaGenerator* */ generator, out IntPtr /* const int32_t** */ outTokenIds, diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index 6ffc34c15b..a2cf44b95f 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -19,7 +19,11 @@ struct Search_Cuda : Search { cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event - bool IsMaxLength() const { + bool HitEOS() const { + cudaStreamSynchronize(GetStream()); + return *hit_eos_cpu_; + } + bool HitMaxLength() const { cudaStreamSynchronize(GetStream()); return *hit_max_length_cpu_; } diff --git a/src/generators.cpp b/src/generators.cpp index 14504d658a..80b73809da 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -472,8 +472,12 @@ bool Generator::IsDone() { return is_done; } -bool Generator::IsMaxLength() { - return search_->IsMaxLength(); +bool Generator::HitEOS() { + return search_->HitEOS(); +} + +bool Generator::HitMaxLength() { + return search_->HitMaxLength(); } bool Generator::IsSessionTerminated() const { diff --git a/src/generators.h b/src/generators.h index 6ae02f6cd1..a6bacb2573 100644 --- a/src/generators.h +++ b/src/generators.h @@ -95,7 +95,8 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); - bool IsMaxLength(); + bool HitEOS() const; + bool HitMaxLength() const; void 
AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 66489ed8df..78213860d7 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -61,6 +61,32 @@ public boolean isDone() { return isDone(nativeHandle); } + /** + * Checks if the generation process ended because an EOS token id was hit. + * + * @return true if the EOS token was hit, false otherwise. + */ + public boolean hitEOS() { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + return hitEOS(nativeHandle); + } + + /** + * Checks if the generation process ended because the maximum length was hit. + * + * @return true if the maximum length was hit, false otherwise. + */ + public boolean hitMaxLength() { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + return hitMaxLength(nativeHandle); + } + /** * Add a Tensor as a model input. 
* diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp index c75e397917..5fc72083e2 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp @@ -72,6 +72,16 @@ Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong nati return OgaGenerator_IsDone(reinterpret_cast(native_handle)); } +JNIEXPORT jboolean JNICALL +Java_ai_onnxruntime_genai_Generator_hitEOS(JNIEnv* env, jobject thiz, jlong native_handle) { + return OgaGenerator_HitEOS(reinterpret_cast(native_handle)); +} + +JNIEXPORT jboolean JNICALL +Java_ai_onnxruntime_genai_Generator_hitMaxLength(JNIEnv* env, jobject thiz, jlong native_handle) { + return OgaGenerator_HitMaxLength(reinterpret_cast(native_handle)); +} + JNIEXPORT void JNICALL Java_ai_onnxruntime_genai_Generator_rewindTo(JNIEnv* env, jobject thiz, jlong native_handle, jlong length) { ThrowIfError(env, OgaGenerator_RewindTo(reinterpret_cast(native_handle), length)); diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index 9ff082ffc6..efe8404255 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -234,6 +234,20 @@ typedef NS_ENUM(NSInteger, OGAElementType) { */ - (BOOL)isDoneWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); +/** + * Whether generation ended because an EOS token id was hit. + * @param error Optional error information set if an error occurs. + * @return The result, or false if an error occurs. + */ +- (BOOL)hitEOSWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); + +/** + * Whether generation ended because the maximum length was hit. + * @param error Optional error information set if an error occurs. + * @return The result, or false if an error occurs. 
+ */ +- (BOOL)hitMaxLengthWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); + /** * Set input with NamedTensors type. * @param namedTensors The named tensors. diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm index b5ca0e941c..b4874594dc 100644 --- a/src/objectivec/oga_generator.mm +++ b/src/objectivec/oga_generator.mm @@ -30,6 +30,20 @@ - (BOOL)isDoneWithError:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } +- (BOOL)hitEOSWithError:(NSError**)error { + try { + return _generator->HitEOS(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + +- (BOOL)hitMaxLengthWithError:(NSError**)error { + try { + return _generator->HitMaxLength(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + - (BOOL)setInputs:(OGANamedTensors*)namedTensors error:(NSError**)error { try { _generator->SetInputs([namedTensors CXXAPIOgaNamedTensors]); diff --git a/src/ort_genai.h b/src/ort_genai.h index dc38a1b967..b0c8ce379d 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -448,6 +448,14 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } + bool HitEOS() { + return OgaGenerator_HitEOS(this); + } + + bool HitMaxLength() { + return OgaGenerator_HitMaxLength(this); + } + void SetModelInput(const char* name, OgaTensor& tensor) { OgaCheckResult(OgaGenerator_SetModelInput(this, name, &tensor)); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index e18b424b0b..bebfb0c1db 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -413,6 +413,14 @@ bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator) { return generator->IsDone(); } +bool OGA_API_CALL OgaGenerator_HitEOS(OgaGenerator* generator) { + return generator->HitEOS(); +} + +bool OGA_API_CALL OgaGenerator_HitMaxLength(OgaGenerator* generator) { + return generator->HitMaxLength(); +} + bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator) { return generator->IsSessionTerminated(); } diff 
--git a/src/ort_genai_c.h b/src/ort_genai_c.h index 2b6e4156ed..e08a3891c2 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -446,6 +446,21 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); * \return True if the generator has finished generating all the sequences, false otherwise. */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator); + +/** + * \brief Returns true if the generator is done because it hit the EOS token id after generating all the sequences. + * \param[in] generator The generator to check if it is done with generating all sequences. + * \return True if the generator has hit the EOS token id, false otherwise. + */ +OGA_EXPORT bool OGA_API_CALL OgaGenerator_HitEOS(OgaGenerator* generator); + +/** + * \brief Returns true if the generator is done because it hit the maximum length after generating all the sequences. + * \param[in] generator The generator to check if it is done with generating all sequences. + * \return True if the generator has hit the maximum length, false otherwise. 
+ */ +OGA_EXPORT bool OGA_API_CALL OgaGenerator_HitMaxLength(OgaGenerator* generator); + OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator); /** diff --git a/src/python/python.cpp b/src/python/python.cpp index 547a116b0f..2fac70b78b 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -260,6 +260,14 @@ struct PyGenerator { return generator_->IsDone(); } + bool HitEOS() { + return generator_->HitEOS(); + } + + bool HitMaxLength() { + return generator_->HitMaxLength(); + } + void SetActiveAdapter(OgaAdapters& adapters, const std::string& adapter_name) { generator_->SetActiveAdapter(adapters, adapter_name.c_str()); } diff --git a/src/search.h b/src/search.h index 4c47d66a0a..756d02bcd0 100644 --- a/src/search.h +++ b/src/search.h @@ -19,7 +19,8 @@ struct Search : LeakChecked { virtual DeviceSpan GetLogits() const = 0; virtual void SetLogits(DeviceSpan logits) = 0; virtual bool IsDone() const = 0; - virtual bool IsMaxLength() const = 0; + virtual bool HitEOS() const = 0; + virtual bool HitMaxLength() const = 0; virtual void ResetDone() = 0; virtual void SelectTop() = 0; @@ -46,7 +47,8 @@ struct Search_Cpu : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const override { return done_; } - bool IsMaxLength() const override { return hit_max_length_; } + bool HitEOS() const override { return hit_eos_; } + bool HitMaxLength() const override { return hit_max_length_; } void ResetDone() override; DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; From b12e2602c29c8421d15f68bad750085e5302de74 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 18 Dec 2025 07:05:08 +0000 Subject: [PATCH 03/38] Add language bindings and update unit tests --- README.md | 18 +++--- examples/c/src/model_chat.cpp | 4 +- examples/c/src/model_qa.cpp | 4 +- examples/c/src/model_vision.cpp | 4 +- examples/c/src/phi4-mm.cpp | 4 +- examples/c/src/whisper.cpp | 4 +- 
examples/csharp/HelloPhi/Program.cs | 12 ++-- examples/csharp/HelloPhi3V/Program.cs | 4 +- examples/csharp/HelloPhi4MM/Program.cs | 4 +- examples/python/awq-quantized-model.py | 4 +- examples/python/guidance-example.py | 2 + examples/python/model-chat.py | 4 +- examples/python/model-generate.py | 4 +- examples/python/model-qa.py | 4 +- examples/python/model-vision.py | 4 +- examples/python/phi3-qa.py | 4 +- examples/python/phi4-mm.py | 4 +- examples/python/whisper.py | 2 + src/generators.cpp | 4 +- src/python/python.cpp | 2 + src/search.cpp | 4 ++ src/search.h | 1 + test/c_api_tests.cpp | 64 +++++++++---------- test/csharp/TestOnnxRuntimeGenAIAPI.cs | 32 +++++----- test/model_tests.cpp | 40 ++++++------ .../ios_package_uitest_cpp_api.mm | 4 +- .../macos_package_uitest_cpp_api.mm | 4 +- test/python/test_onnxruntime_genai_api.py | 52 +++++++-------- test/python/test_onnxruntime_genai_e2e.py | 4 +- 29 files changed, 156 insertions(+), 145 deletions(-) diff --git a/README.md b/README.md index 0c2f6b1ebc..245de8f0b1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ONNX Runtime GenAI -Note: between `v0.11.0` and `v0.10.1`, there is a breaking API usage change to improve model quality during multi-turn conversations. +Note: between `v0.11.5` and `v0.10.1`, there is a breaking API usage change to improve model quality during multi-turn conversations. Previously, the decoding loop could be written as follows. @@ -11,12 +11,12 @@ while not IsDone(): PrintLastToken() ``` -In 0.11.0, the decoding loop should now be written as follows. +In 0.11.5, the decoding loop should now be written as follows. 
``` -while True: +while not IsDone(): GenerateToken() - if IsDone(): + if HitEOS(): break GetLastToken() PrintLastToken() @@ -38,7 +38,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena |Support matrix|Supported now|Under development|On the roadmap| | -------------- | ------------- | ----------------- | -------------- | -| Model architectures | AMD OLMo
ChatGLM
DeepSeek
ERNIE 4.5
Gemma
gpt-oss
Granite
Llama
Mistral
Nemotron
Phi (language + vision)
Qwen
SmolLM3
Whisper | Stable diffusion | Multi-modal models | +| Model architectures | AMD OLMo
ChatGLM
DeepSeek
ERNIE 4.5
Fara
Gemma
gpt-oss
Granite
Llama
Mistral
Nemotron
Phi (language + vision)
Qwen
SmolLM3
Whisper | Stable diffusion | Multi-modal models | | API| Python
C#
C/C++
Java ^ | Objective-C || | Platform | Linux
Windows
Mac ^
Android ^ || iOS ||| | Architecture | x86
x64
Arm64 ~ |||| @@ -100,9 +100,9 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install) try: generator.append_tokens(input_tokens) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end='', flush=True) @@ -118,13 +118,13 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install) Due to the evolving nature of this project and ongoing feature additions, examples in the `main` branch may not always align with the latest stable release. This section outlines how to ensure compatibility between the examples and the corresponding version. The majority of the steps would remain same. Just the package installation and the model example file would change. ### Stable version -Install the package according to the [installation instructions](https://onnxruntime.ai/docs/genai/howto/install). Let's say you installed the 0.10.1 version of ONNX Runtime GenAI, so the instructions would look like this: +Install the package according to the [installation instructions](https://onnxruntime.ai/docs/genai/howto/install). 
Let's say you installed the 0.11.5 version of ONNX Runtime GenAI, so the instructions would look like this: ```bash # Clone the repo git clone https://github.com/microsoft/onnxruntime-genai.git && cd onnxruntime-genai # Checkout the branch for the version you are using -git checkout v0.10.1 +git checkout v0.11.5 cd examples ``` diff --git a/examples/c/src/model_chat.cpp b/examples/c/src/model_chat.cpp index d8d6b7a882..83981d8825 100644 --- a/examples/c/src/model_chat.cpp +++ b/examples/c/src/model_chat.cpp @@ -91,7 +91,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { const auto current_token_count = generator->GetSequenceCount(0); try { - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); if (is_first_token) { @@ -99,7 +99,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } diff --git a/examples/c/src/model_qa.cpp b/examples/c/src/model_qa.cpp index 9c06659057..b13cce89f0 100644 --- a/examples/c/src/model_qa.cpp +++ b/examples/c/src/model_qa.cpp @@ -82,7 +82,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { generator->AppendTokenSequences(*sequences); try { - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); if (is_first_token) { @@ -90,7 +90,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } diff --git a/examples/c/src/model_vision.cpp b/examples/c/src/model_vision.cpp index 1be5de0470..7b40a7b410 100644 --- a/examples/c/src/model_vision.cpp +++ b/examples/c/src/model_vision.cpp @@ -90,10 +90,10 @@ void CXX_API(const char* model_path, const char* execution_provider) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*input_tensors); - while (true) { + while (!generator->IsDone()) { 
generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } diff --git a/examples/c/src/phi4-mm.cpp b/examples/c/src/phi4-mm.cpp index ab0be4d77c..be600f6c51 100644 --- a/examples/c/src/phi4-mm.cpp +++ b/examples/c/src/phi4-mm.cpp @@ -101,10 +101,10 @@ void CXX_API(const char* model_path, const char* execution_provider) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*input_tensors); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } diff --git a/examples/c/src/whisper.cpp b/examples/c/src/whisper.cpp index 48a657972c..2824a752eb 100644 --- a/examples/c/src/whisper.cpp +++ b/examples/c/src/whisper.cpp @@ -57,9 +57,9 @@ void CXX_API(const char* model_path, int32_t num_beams) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*inputs); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index c23ab9907c..364cd20aa7 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -127,10 +127,10 @@ static string GetPrompt(bool interactive) using var generator = new Generator(model, generatorParams); generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -155,10 +155,10 @@ static string GetPrompt(bool interactive) using var generator = new Generator(model, generatorParams); generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if 
(generator.HitEOS()) { break; } @@ -196,10 +196,10 @@ static string GetPrompt(bool interactive) var sequences = tokenizer.Encode(tokenizer.ApplyChatTemplate("", messages, "", true)); var watch = System.Diagnostics.Stopwatch.StartNew(); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } diff --git a/examples/csharp/HelloPhi3V/Program.cs b/examples/csharp/HelloPhi3V/Program.cs index 09e1038bd5..bb80aad981 100644 --- a/examples/csharp/HelloPhi3V/Program.cs +++ b/examples/csharp/HelloPhi3V/Program.cs @@ -168,10 +168,10 @@ void PrintUsage() using var generator = new Generator(model, generatorParams); generator.SetInputs(inputTensors); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } diff --git a/examples/csharp/HelloPhi4MM/Program.cs b/examples/csharp/HelloPhi4MM/Program.cs index bcec8d714f..1b8f27d459 100644 --- a/examples/csharp/HelloPhi4MM/Program.cs +++ b/examples/csharp/HelloPhi4MM/Program.cs @@ -215,10 +215,10 @@ void PrintUsage() using var generator = new Generator(model, generatorParams); generator.SetInputs(inputTensors); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } diff --git a/examples/python/awq-quantized-model.py b/examples/python/awq-quantized-model.py index fd9991b3c7..b960ad8359 100644 --- a/examples/python/awq-quantized-model.py +++ b/examples/python/awq-quantized-model.py @@ -108,9 +108,9 @@ def run_model(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] 
diff --git a/examples/python/guidance-example.py b/examples/python/guidance-example.py index 3d9a068783..9bcf5b1d51 100644 --- a/examples/python/guidance-example.py +++ b/examples/python/guidance-example.py @@ -54,6 +54,8 @@ def main(args): full_seq_str = "" while not generator.is_done(): generator.generate_next_token() + if generator.hit_eos(): + break # NOTE: since get_next_tokens returns only the last token, we'll need to use get_sequence instead # new_tokens = generator.get_next_tokens()[0] diff --git a/examples/python/model-chat.py b/examples/python/model-chat.py index 44c9add4ff..5412f4b009 100644 --- a/examples/python/model-chat.py +++ b/examples/python/model-chat.py @@ -216,14 +216,14 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py index 649d393c4a..e8b6fefc26 100644 --- a/examples/python/model-generate.py +++ b/examples/python/model-generate.py @@ -99,9 +99,9 @@ def main(args): if args.verbose: print("Generating tokens ...\n") start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break run_time = time.time() - start_time diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index cfd31cd946..813e709295 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -207,14 +207,14 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): + if generator.hit_eos(): break new_token = 
generator.get_next_tokens()[0] diff --git a/examples/python/model-vision.py b/examples/python/model-vision.py index acd95f1ff1..cf9cc2f4c9 100644 --- a/examples/python/model-vision.py +++ b/examples/python/model-vision.py @@ -150,9 +150,9 @@ def run(args: argparse.Namespace): generator.set_inputs(inputs) start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py index e0296d1905..bea9804e5d 100644 --- a/examples/python/phi3-qa.py +++ b/examples/python/phi3-qa.py @@ -75,14 +75,14 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] diff --git a/examples/python/phi4-mm.py b/examples/python/phi4-mm.py index 80a1413942..822a4f0a99 100644 --- a/examples/python/phi4-mm.py +++ b/examples/python/phi4-mm.py @@ -145,9 +145,9 @@ def run(args: argparse.Namespace): generator.set_inputs(inputs) start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break new_token = generator.get_next_tokens()[0] diff --git a/examples/python/whisper.py b/examples/python/whisper.py index f460337996..efc382d8cc 100644 --- a/examples/python/whisper.py +++ b/examples/python/whisper.py @@ -70,6 +70,8 @@ def run(args: argparse.Namespace): while not generator.is_done(): generator.generate_next_token() + if generator.hit_eos(): + break print() transcriptions = [] diff --git a/src/generators.cpp b/src/generators.cpp index 80b73809da..6adcb8006f 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -472,11 +472,11 @@ bool 
Generator::IsDone() { return is_done; } -bool Generator::HitEOS() { +bool Generator::HitEOS() const { return search_->HitEOS(); } -bool Generator::HitMaxLength() { +bool Generator::HitMaxLength() const { return search_->HitMaxLength(); } diff --git a/src/python/python.cpp b/src/python/python.cpp index 2fac70b78b..3ce581bd1d 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -463,6 +463,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "Generator") .def(pybind11::init()) .def("is_done", &PyGenerator::IsDone) + .def("hit_eos", &PyGenerator::HitEOS) + .def("hit_max_length", &PyGenerator::HitMaxLength) .def("get_input", &PyGenerator::GetInput) .def("get_output", &PyGenerator::GetOutput) .def("set_inputs", &PyGenerator::SetInputs) diff --git a/src/search.cpp b/src/search.cpp index 92af7b96e7..24d5d97b47 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -54,6 +54,10 @@ void Search_Cpu::ResetDone() { done_ = false; hit_eos_ = false; hit_max_length_ = false; +} + +void GreedySearch_Cpu::ResetDone() { + Search_Cpu::ResetDone(); not_done_count_ = params_->search.batch_size; memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); } diff --git a/src/search.h b/src/search.h index 756d02bcd0..f81ac93e85 100644 --- a/src/search.h +++ b/src/search.h @@ -85,6 +85,7 @@ struct GreedySearch_Cpu : Search_Cpu { // Used by continuous decoding search. 
void AppendTokens(DeviceSpan& next_tokens) override; void RewindTo(size_t index) override; + void ResetDone() override; protected: void SetNextToken(size_t batch_id, int32_t token); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 08af900b56..b40af6ddda 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -338,9 +338,9 @@ TEST(CAPITests, EndToEndPhiBatch) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -519,9 +519,9 @@ TEST(CAPITests, EndToEndPhi) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -561,9 +561,9 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -658,9 +658,9 @@ TEST(CAPITests, LoadModelFromMemory) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -738,9 +738,9 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -903,9 +903,9 @@ TEST(CAPITests, SetTerminate) { auto GenerateOutput = 
[](OgaGenerator* generator, std::unique_ptr tokenizer_stream) { EXPECT_THROW({ - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } }, std::runtime_error); @@ -968,9 +968,9 @@ struct Phi2Test { auto generator = OgaGenerator::Create(*model_, *params_); generator->AppendTokenSequences(*input_sequences_); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1130,9 +1130,9 @@ TEST(CAPITests, AdaptersTest) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1156,9 +1156,9 @@ TEST(CAPITests, AdaptersTest) { generator->SetActiveAdapter(*adapters, "adapters_a_and_b"); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1210,9 +1210,9 @@ TEST(CAPITests, AdaptersTestMultipleAdapters) { generator->SetActiveAdapter(*adapters, "adapter_b"); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1256,9 +1256,9 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1278,9 +1278,9 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { generator->RewindTo(0); generator->AppendTokens(input_ids.data(), input_ids.size()); - while 
(true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1316,9 +1316,9 @@ TEST(CAPITests, RewindGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1334,9 +1334,9 @@ TEST(CAPITests, RewindGptFp32CAPI) { // Rewind to length 5 and verify same output generator->RewindTo(5); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1353,9 +1353,9 @@ TEST(CAPITests, RewindGptFp32CAPI) { std::vector next_ids{731, 731}; generator->AppendTokens(next_ids.data(), next_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -1386,9 +1386,9 @@ TEST(CAPITests, SetGuidance) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index 3b1c72d395..0496c39ac8 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -161,10 +161,10 @@ public void TestGreedySearch() generator.AppendTokens(inputIDs); Assert.False(generator.IsDone()); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -216,10 +216,10 @@ public void TestLoadModelFromMemory() generator.AppendTokens(inputIDs); Assert.False(generator.IsDone()); - while (true) + while 
(!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -296,10 +296,10 @@ public void TestTopKSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -359,10 +359,10 @@ public void TestTopPSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -424,10 +424,10 @@ public void TestTopKTopPSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -635,10 +635,10 @@ public void TestPhi2() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -745,10 +745,10 @@ public void TestAdapters() using var generator = new Generator(model, genParams); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } @@ -777,10 +777,10 @@ public void TestAdapters() using var generator = new Generator(model, genParams); generator.SetActiveAdapter(adapters, "adapters_a_and_b"); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) + if (generator.HitEOS()) { break; } diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 56ce5696e0..7eecc499c0 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -80,9 +80,9 @@ TEST(ModelTests, GreedySearchGptFp32) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while 
(true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -124,9 +124,9 @@ TEST(ModelTests, BeamSearchGptFp32) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -161,9 +161,9 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -186,9 +186,9 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -201,9 +201,9 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) generator->RewindTo(3); std::vector next_ids{731, 731}; generator->AppendTokens(next_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -247,9 +247,9 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -299,9 +299,9 @@ void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* mo auto generator = OgaGenerator::Create(*model, *params); 
generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -352,9 +352,9 @@ void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const cha auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -394,9 +394,9 @@ Print all primes between 1 and n auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(tokens->Get(0)); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } @@ -428,9 +428,9 @@ Print all primes between 1 and n auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(tokens->Get(0)); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } diff --git a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 9ff3bee78c..0123b6e7f6 100644 --- a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -54,9 +54,9 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } diff --git a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm 
b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index a4e5fdfc63..30afa758f0 100644 --- a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -54,9 +54,9 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { + if (generator->HitEOS()) { break; } } diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index cd6dd43c2d..655f02c3b9 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -148,14 +148,14 @@ def test_greedy_search(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): # Test getting/setting logits logits = generator.get_logits() generator.set_logits(logits) generator.set_logits(logits) # twice just to be sure buffer is still valid generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break expected_sequence = np.array( @@ -193,9 +193,9 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break assert generator.get_sequence(0) is not None @@ -203,9 +203,9 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator.rewind_to(3) generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) - while True: + while not 
generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break assert generator.get_sequence(0) is not None @@ -217,9 +217,9 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731], [64, 65, 66, 67]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(batch_size): @@ -233,9 +233,9 @@ def test_rewind_cuda(test_data_path, relative_model_path): dtype=np.int32, ) ) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(batch_size): @@ -262,9 +262,9 @@ def test_rewind(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break assert np.array_equal(expected_sequence, generator.get_sequence(0)) @@ -272,9 +272,9 @@ def test_rewind(test_data_path, relative_model_path): generator.rewind_to(3) generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break assert np.array_equal(expected_sequence, generator.get_sequence(0)) @@ -403,9 +403,9 @@ def test_batching(device, phi2_for): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -434,9 +434,9 @@ def test_e2e(device, phi2_for): 
generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -469,9 +469,9 @@ def test_load_model_from_memory(device, wrapper_bytes_function, phi2_for): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -648,9 +648,9 @@ def _split(onnx_model_path: os.PathLike, output_dir: os.PathLike): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break expected_output = [ @@ -848,9 +848,9 @@ def _export_adapter(adapter, adapter_file_name): generator.set_active_adapter(adapters, f"adapter_{i}") generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break @@ -935,9 +935,9 @@ def _prepare_model(test_data_path): else: generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index 643882ae5c..167658f703 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -31,9 +31,9 @@ def run_model(model_path: str | bytes | os.PathLike): generator = og.Generator(model, params) 
generator.append_tokens(sequences) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): + if generator.hit_eos(): break for i in range(3): From 795719d820eb91f6f5014d6a716d742984542d0d Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 18 Dec 2025 08:26:15 +0000 Subject: [PATCH 04/38] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 245de8f0b1..87dfebaa31 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ while not IsDone(): PrintLastToken() ``` -Please read [this PR's description](https://github.com/microsoft/onnxruntime-genai/pull/1849) for more information. +Please read [this PR's description](https://github.com/microsoft/onnxruntime-genai/pull/1925) for more information. ## Status From 4a13b80270d83faf8441713cd2665744dd52d421 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 18 Dec 2025 17:54:33 +0000 Subject: [PATCH 05/38] Fix Java build and initialize pointers --- src/cuda/search_cuda.cpp | 2 ++ src/java/src/main/java/ai/onnxruntime/genai/Generator.java | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index 34d74a36ee..9e48537bc8 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -26,6 +26,8 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) eos_seen_buffer_ = CudaMallocArray(batch_beam_size, &eos_seen_); done_cpu_ = CudaMallocHostArray(1); + hit_eos_cpu_ = CudaMallocHostArray(1); + hit_max_length_cpu_ = CudaMallocHostArray(1); eos_token_ids_ = params.p_device->Allocate(params.config.model.eos_token_id.size()); copy(std::span{params.config.model.eos_token_id}, eos_token_ids_.CpuSpan()); diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 78213860d7..0fd0471af8 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ 
b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -296,6 +296,10 @@ private native long createGenerator(long modelHandle, long generatorParamsHandle private native boolean isDone(long nativeHandle); + private native boolean hitEOS(long nativeHandle); + + private native boolean hitMaxLength(long nativeHandle); + private native void setModelInput(long nativeHandle, String inputName, long tensorHandle) throws GenAIException; From 3165142d6b36fa08ebbe5d2d2c2cbf25dbd96a3e Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 20 Dec 2025 00:27:34 +0000 Subject: [PATCH 06/38] Add checks for beam search --- src/cuda/search_cuda.cpp | 4 +++- src/cuda/search_cuda.cu | 7 ++++--- src/cuda/search_cuda.cuh | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index 9e48537bc8..f4ebfd8820 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -171,7 +171,7 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { // Check for EOS assert(next_tokens_.size() == eos_seen_.size()); - cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_seen_.data(), eos_token_ids_.Span().data(), static_cast(eos_token_ids_.Span().size()), params_->config.model.pad_token_id, done_cpu_.get(), GetStream()); + cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_seen_.data(), eos_token_ids_.Span().data(), static_cast(eos_token_ids_.Span().size()), params_->config.model.pad_token_id, done_cpu_.get(), hit_eos_cpu_.get(), GetStream()); // Append tokens cudaStreamSynchronize(GetStream()); @@ -249,10 +249,12 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { } void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { + ResetDone(); auto next_tokens_gpu = next_tokens.Span(); cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetNextSequences().Span(), params_->search.batch_size, 
params_->search.num_beams, sequences_.max_length_, GetStream()); cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->search.batch_size, params_->search.num_beams, sequences_.max_length_, GetStream()); sequences_.AfterAppendNextTokens(next_tokens, params_->search.batch_size); // next_tokens is batch_size + ResetDone(); cudaStreamSynchronize(GetStream()); } diff --git a/src/cuda/search_cuda.cu b/src/cuda/search_cuda.cu index e72689291b..53c9690a65 100644 --- a/src/cuda/search_cuda.cu +++ b/src/cuda/search_cuda.cu @@ -74,7 +74,7 @@ struct ArgMaxDataImpl : ArgMaxData { cuda_unique_ptr> argmaxen_owner_; }; -__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu) { +__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu) { for (int batch_id = 0; batch_id < next_tokens_count; ++batch_id) { // If EOS already met, pad if (eos_seen[batch_id]) { @@ -103,13 +103,14 @@ __global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, b } if (batch_id == next_tokens_count) { *done_cpu = true; + *hit_eos_cpu = true; return; } } } -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, cudaStream_t stream) { - CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_seen, eos_token_ids, eos_token_count, pad_token_id, done_cpu); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream) { + CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_seen, eos_token_ids, 
eos_token_count, pad_token_id, done_cpu, hit_eos_cpu); } __global__ void AddProbsKernel(float* log_probs, diff --git a/src/cuda/search_cuda.cuh b/src/cuda/search_cuda.cuh index 3aa5aa0fc5..d039bacae3 100644 --- a/src/cuda/search_cuda.cuh +++ b/src/cuda/search_cuda.cuh @@ -9,8 +9,8 @@ struct ArgMaxData { virtual ~ArgMaxData() = default; }; -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, bool* done_cpu, cudaStream_t stream); -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, cudaStream_t stream); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream); void Launch_ExpandInputSequences(const std::span input_sequences, std::span sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream); void Launch_AppendNextTokensToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int past_length, int max_length, cudaStream_t stream); void Launch_GetLastTokens(int32_t* next_tokens, const int32_t* sequences, int batch_beam_size, int sequence_length, int max_length, cudaStream_t stream); From 74c4fac9f35e8c8d173a2b011e422fec3e94b619 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 21 Jan 2026 08:59:56 +0000 Subject: [PATCH 07/38] Remove HitEOS from examples since GetNextTokens works in streaming mode --- README.md | 2 - examples/c/src/model_chat.cpp | 4 -- examples/c/src/model_qa.cpp | 4 -- examples/c/src/model_vision.cpp | 4 -- examples/c/src/phi4-mm.cpp | 4 -- 
examples/c/src/whisper.cpp | 3 -- examples/csharp/HelloPhi/Program.cs | 12 ----- examples/csharp/HelloPhi3V/Program.cs | 4 -- examples/csharp/HelloPhi4MM/Program.cs | 4 -- examples/python/awq-quantized-model.py | 3 -- examples/python/guidance-example.py | 2 - examples/python/model-chat.py | 3 -- examples/python/model-generate.py | 2 - examples/python/model-qa.py | 3 -- examples/python/model-vision.py | 3 -- examples/python/phi3-qa.py | 3 -- examples/python/phi4-mm.py | 3 -- examples/python/whisper.py | 2 - test/c_api_tests.cpp | 48 ------------------- test/csharp/TestOnnxRuntimeGenAIAPI.cs | 32 ------------- test/model_tests.cpp | 30 ------------ .../ios_package_uitest_cpp_api.mm | 3 -- .../macos_package_uitest_cpp_api.mm | 3 -- test/python/test_onnxruntime_genai_api.py | 26 ---------- test/python/test_onnxruntime_genai_e2e.py | 2 - 25 files changed, 209 deletions(-) diff --git a/README.md b/README.md index 0f6d16703a..f57edf9248 100644 --- a/README.md +++ b/README.md @@ -80,8 +80,6 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install) generator.append_tokens(input_tokens) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end='', flush=True) except KeyboardInterrupt: diff --git a/examples/c/src/model_chat.cpp b/examples/c/src/model_chat.cpp index 83981d8825..a4bfac7dc9 100644 --- a/examples/c/src/model_chat.cpp +++ b/examples/c/src/model_chat.cpp @@ -99,10 +99,6 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->HitEOS()) { - break; - } - const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; diff --git a/examples/c/src/model_qa.cpp b/examples/c/src/model_qa.cpp index b13cce89f0..191b924fd8 100644 --- 
a/examples/c/src/model_qa.cpp +++ b/examples/c/src/model_qa.cpp @@ -90,10 +90,6 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->HitEOS()) { - break; - } - const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; diff --git a/examples/c/src/model_vision.cpp b/examples/c/src/model_vision.cpp index 7b40a7b410..9d7604cfd2 100644 --- a/examples/c/src/model_vision.cpp +++ b/examples/c/src/model_vision.cpp @@ -93,10 +93,6 @@ void CXX_API(const char* model_path, const char* execution_provider) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } - const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << stream->Decode(new_token) << std::flush; diff --git a/examples/c/src/phi4-mm.cpp b/examples/c/src/phi4-mm.cpp index be600f6c51..db5ef5719b 100644 --- a/examples/c/src/phi4-mm.cpp +++ b/examples/c/src/phi4-mm.cpp @@ -104,10 +104,6 @@ void CXX_API(const char* model_path, const char* execution_provider) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } - const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << stream->Decode(new_token) << std::flush; diff --git a/examples/c/src/whisper.cpp b/examples/c/src/whisper.cpp index 2824a752eb..32e8c9e029 100644 --- a/examples/c/src/whisper.cpp +++ b/examples/c/src/whisper.cpp @@ -59,9 +59,6 @@ void CXX_API(const char* model_path, int32_t num_beams) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } for (size_t i = 0; i < static_cast(num_beams * batch_size); ++i) { diff --git 
a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index 364cd20aa7..e2f64dfc8f 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -130,10 +130,6 @@ static string GetPrompt(bool interactive) while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } var outputSequence = generator.GetSequence(0); @@ -158,10 +154,6 @@ static string GetPrompt(bool interactive) while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); @@ -199,10 +191,6 @@ static string GetPrompt(bool interactive) while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); diff --git a/examples/csharp/HelloPhi3V/Program.cs b/examples/csharp/HelloPhi3V/Program.cs index bb80aad981..1d32cac199 100644 --- a/examples/csharp/HelloPhi3V/Program.cs +++ b/examples/csharp/HelloPhi3V/Program.cs @@ -171,10 +171,6 @@ void PrintUsage() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); diff --git a/examples/csharp/HelloPhi4MM/Program.cs b/examples/csharp/HelloPhi4MM/Program.cs index 1b8f27d459..ce0ddf359d 100644 --- a/examples/csharp/HelloPhi4MM/Program.cs +++ b/examples/csharp/HelloPhi4MM/Program.cs @@ -218,10 +218,6 @@ void PrintUsage() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); diff --git a/examples/python/awq-quantized-model.py b/examples/python/awq-quantized-model.py index b960ad8359..465935e264 100644 --- a/examples/python/awq-quantized-model.py +++ 
b/examples/python/awq-quantized-model.py @@ -110,9 +110,6 @@ def run_model(args): try: while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) except KeyboardInterrupt: diff --git a/examples/python/guidance-example.py b/examples/python/guidance-example.py index 9bcf5b1d51..3d9a068783 100644 --- a/examples/python/guidance-example.py +++ b/examples/python/guidance-example.py @@ -54,8 +54,6 @@ def main(args): full_seq_str = "" while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break # NOTE: since get_next_tokens returns only the last token, we'll need to use get_sequence instead # new_tokens = generator.get_next_tokens()[0] diff --git a/examples/python/model-chat.py b/examples/python/model-chat.py index 5412f4b009..ca9746b16a 100644 --- a/examples/python/model-chat.py +++ b/examples/python/model-chat.py @@ -223,9 +223,6 @@ def main(args): first_token_timestamp = time.time() first = False - if generator.hit_eos(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py index e8b6fefc26..86be56922b 100644 --- a/examples/python/model-generate.py +++ b/examples/python/model-generate.py @@ -101,8 +101,6 @@ def main(args): start_time = time.time() while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break run_time = time.time() - start_time for i in range(len(prompts)): diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index 813e709295..a9b2cf7802 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -214,9 +214,6 @@ def main(args): first_token_timestamp = time.time() first = False - if generator.hit_eos(): - break - new_token = 
generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/model-vision.py b/examples/python/model-vision.py index cf9cc2f4c9..e226697e15 100644 --- a/examples/python/model-vision.py +++ b/examples/python/model-vision.py @@ -152,9 +152,6 @@ def run(args: argparse.Namespace): while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break - new_token = generator.get_next_tokens()[0] print(stream.decode(new_token), end="", flush=True) diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py index bea9804e5d..645954fee3 100644 --- a/examples/python/phi3-qa.py +++ b/examples/python/phi3-qa.py @@ -82,9 +82,6 @@ def main(args): first_token_timestamp = time.time() first = False - if generator.hit_eos(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/phi4-mm.py b/examples/python/phi4-mm.py index 822a4f0a99..c9aa155cd4 100644 --- a/examples/python/phi4-mm.py +++ b/examples/python/phi4-mm.py @@ -147,9 +147,6 @@ def run(args: argparse.Namespace): while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) diff --git a/examples/python/whisper.py b/examples/python/whisper.py index efc382d8cc..f460337996 100644 --- a/examples/python/whisper.py +++ b/examples/python/whisper.py @@ -70,8 +70,6 @@ def run(args: argparse.Namespace): while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break print() transcriptions = [] diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index b40af6ddda..d8e38dc8f5 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -340,9 +340,6 @@ TEST(CAPITests, EndToEndPhiBatch) { while (!generator->IsDone()) { 
generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Decode The Batch @@ -521,9 +518,6 @@ TEST(CAPITests, EndToEndPhi) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Decode The Batch @@ -563,9 +557,6 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Decode The Batch @@ -660,9 +651,6 @@ TEST(CAPITests, LoadModelFromMemory) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Decode The Batch @@ -740,9 +728,6 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { generator->AppendTokens(input_ids.data(), input_ids.size()); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -905,9 +890,6 @@ TEST(CAPITests, SetTerminate) { EXPECT_THROW({ while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } }, std::runtime_error); }; @@ -970,9 +952,6 @@ struct Phi2Test { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Decode One at a time @@ -1132,9 +1111,6 @@ TEST(CAPITests, AdaptersTest) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } auto logits = generator->GetOutput("logits"); @@ -1158,9 +1134,6 @@ TEST(CAPITests, AdaptersTest) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } auto logits = generator->GetOutput("logits"); @@ -1212,9 +1185,6 @@ TEST(CAPITests, AdaptersTestMultipleAdapters) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } } @@ -1258,9 +1228,6 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { generator->AppendTokens(input_ids.data(), 
input_ids.size()); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -1280,9 +1247,6 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { generator->AppendTokens(input_ids.data(), input_ids.size()); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -1318,9 +1282,6 @@ TEST(CAPITests, RewindGptFp32CAPI) { generator->AppendTokens(input_ids.data(), input_ids.size()); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -1336,9 +1297,6 @@ TEST(CAPITests, RewindGptFp32CAPI) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -1355,9 +1313,6 @@ TEST(CAPITests, RewindGptFp32CAPI) { generator->AppendTokens(next_ids.data(), next_ids.size()); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -1388,9 +1343,6 @@ TEST(CAPITests, SetGuidance) { generator->AppendTokenSequences(*input_sequences); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } auto out_string = tokenizer->Decode(generator->GetSequenceData(0), generator->GetSequenceCount(0)); auto output = std::string(out_string).substr(std::string(input_string).size()); diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index 0496c39ac8..dec620d774 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -164,10 +164,6 @@ public void TestGreedySearch() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -219,10 
+215,6 @@ public void TestLoadModelFromMemory() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -299,10 +291,6 @@ public void TestTopKSearch() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -362,10 +350,6 @@ public void TestTopPSearch() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -427,10 +411,6 @@ public void TestTopKTopPSearch() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -638,10 +618,6 @@ public void TestPhi2() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -748,10 +724,6 @@ public void TestAdapters() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } using var logits = generator.GetOutput("logits"); @@ -780,10 +752,6 @@ public void TestAdapters() while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.HitEOS()) - { - break; - } } using var logits = generator.GetOutput("logits"); if (_useCudaModel) diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 7eecc499c0..a98c0caf8b 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -82,9 +82,6 @@ TEST(ModelTests, GreedySearchGptFp32) { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -126,9 +123,6 @@ TEST(ModelTests, BeamSearchGptFp32) { generator->AppendTokens(input_ids); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -163,9 
+157,6 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -188,9 +179,6 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -203,9 +191,6 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) generator->AppendTokens(next_ids); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -249,9 +234,6 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { generator->AppendTokens(input_ids); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -301,9 +283,6 @@ void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* mo while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } // Verify outputs match expected outputs @@ -354,9 +333,6 @@ void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const cha while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } auto sequence = generator->GetSequence(0); @@ -396,9 +372,6 @@ Print all primes between 1 and n generator->AppendTokens(tokens->Get(0)); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } auto result = generator->GetSequence(0); @@ -430,9 +403,6 @@ Print all primes between 1 and n generator->AppendTokens(tokens->Get(0)); while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } 
auto result = generator->GetSequence(0); diff --git a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 0123b6e7f6..40a46aaf03 100644 --- a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -56,9 +56,6 @@ - (void)testCppAPI_Basic { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } const auto output_sequence_length = generator->GetSequenceCount(0); diff --git a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index 30afa758f0..af5d0046ec 100644 --- a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -56,9 +56,6 @@ - (void)testCppAPI_Basic { while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->HitEOS()) { - break; - } } const auto output_sequence_length = generator->GetSequenceCount(0); diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index 655f02c3b9..5d33a73aaa 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -155,8 +155,6 @@ def test_greedy_search(test_data_path, relative_model_path): generator.set_logits(logits) # twice just to be sure buffer is still valid generator.generate_next_token() - if generator.hit_eos(): - break expected_sequence = np.array( [ @@ -195,8 +193,6 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator.append_tokens(np.array([[0, 0, 195, 731]], 
dtype=np.int32)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break assert generator.get_sequence(0) is not None @@ -205,8 +201,6 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break assert generator.get_sequence(0) is not None @@ -219,8 +213,6 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731], [64, 65, 66, 67]], dtype=np.int32)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(batch_size): assert generator.get_sequence(i) is not None @@ -235,8 +227,6 @@ def test_rewind_cuda(test_data_path, relative_model_path): ) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(batch_size): assert generator.get_sequence(i) is not None @@ -264,8 +254,6 @@ def test_rewind(test_data_path, relative_model_path): generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break assert np.array_equal(expected_sequence, generator.get_sequence(0)) @@ -274,8 +262,6 @@ def test_rewind(test_data_path, relative_model_path): generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break assert np.array_equal(expected_sequence, generator.get_sequence(0)) @@ -405,8 +391,6 @@ def test_batching(device, phi2_for): generator.append_tokens(tokenizer.encode_batch(prompts)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -436,8 +420,6 @@ def test_e2e(device, 
phi2_for): generator.append_tokens(tokenizer.encode_batch(prompts)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -471,8 +453,6 @@ def test_load_model_from_memory(device, wrapper_bytes_function, phi2_for): generator.append_tokens(tokenizer.encode_batch(prompts)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -650,8 +630,6 @@ def _split(onnx_model_path: os.PathLike, output_dir: os.PathLike): generator.append_tokens(tokenizer.encode_batch(prompts)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break expected_output = [ "This is a test.\n # TOD import * doct proofingrad", @@ -850,8 +828,6 @@ def _export_adapter(adapter, adapter_file_name): generator.append_tokens(tokenizer.encode_batch(prompts)) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break @pytest.mark.parametrize("device", devices) @@ -937,8 +913,6 @@ def _prepare_model(test_data_path): while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break @pytest.mark.parametrize("relative_model_path", [Path("audio-preprocessing")]) diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index 167658f703..f28d674e21 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -33,8 +33,6 @@ def run_model(model_path: str | bytes | os.PathLike): generator.append_tokens(sequences) while not generator.is_done(): generator.generate_next_token() - if generator.hit_eos(): - break for i in range(3): assert generator.get_sequence(i) is not None From 5cf30aea06897535af0be8ec3cbb8d485dfae594 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 21 Jan 
2026 11:46:30 +0000 Subject: [PATCH 08/38] Add C++ API for GetNextTokens and use in examples --- examples/c/src/model_chat.cpp | 3 +-- examples/c/src/model_qa.cpp | 3 +-- examples/c/src/model_vision.cpp | 4 +--- examples/c/src/phi4-mm.cpp | 4 +--- src/ort_genai.h | 9 +++++++++ 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/examples/c/src/model_chat.cpp b/examples/c/src/model_chat.cpp index a4bfac7dc9..4efe3a79bf 100644 --- a/examples/c/src/model_chat.cpp +++ b/examples/c/src/model_chat.cpp @@ -99,8 +99,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } } catch (const std::exception& e) { diff --git a/examples/c/src/model_qa.cpp b/examples/c/src/model_qa.cpp index 191b924fd8..b6b4a3c229 100644 --- a/examples/c/src/model_qa.cpp +++ b/examples/c/src/model_qa.cpp @@ -90,8 +90,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } } catch (const std::exception& e) { diff --git a/examples/c/src/model_vision.cpp b/examples/c/src/model_vision.cpp index 9d7604cfd2..b902ecee2c 100644 --- a/examples/c/src/model_vision.cpp +++ b/examples/c/src/model_vision.cpp @@ -92,9 +92,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { while (!generator->IsDone()) { generator->GenerateNextToken(); - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = 
generator->GetNextTokens()[0]; std::cout << stream->Decode(new_token) << std::flush; } diff --git a/examples/c/src/phi4-mm.cpp b/examples/c/src/phi4-mm.cpp index db5ef5719b..e1deb781d4 100644 --- a/examples/c/src/phi4-mm.cpp +++ b/examples/c/src/phi4-mm.cpp @@ -103,9 +103,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { while (!generator->IsDone()) { generator->GenerateNextToken(); - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << stream->Decode(new_token) << std::flush; } diff --git a/src/ort_genai.h b/src/ort_genai.h index b0c8ce379d..7c60e94a31 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -493,6 +493,15 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); return {out, out_count}; } +#else + std::vector GetNextTokens() { + std::vector next_tokens; + const int32_t* out; + size_t out_count; + OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); + next_tokens.assign(out, out + out_count); + return next_tokens; + } #endif void RewindTo(size_t new_length) { From f96fa851e191576252cf18e30298ea68b9bfa53e Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 21 Jan 2026 21:36:34 +0000 Subject: [PATCH 09/38] Introduce TokenCount API instead --- src/csharp/Generator.cs | 16 +++---- src/csharp/NativeMethods.cs | 12 ++--- src/cuda/search_cuda.cpp | 9 +--- src/cuda/search_cuda.cu | 7 ++- src/cuda/search_cuda.cuh | 4 +- src/cuda/search_cuda.h | 10 ----- src/generators.cpp | 12 ++--- src/generators.h | 3 +- .../java/ai/onnxruntime/genai/Generator.java | 45 +++++++------------ .../native/ai_onnxruntime_genai_Generator.cpp | 17 ++++--- src/objectivec/include/ort_genai_objc.h | 21 +++------ src/objectivec/oga_generator.mm | 21 +++------ src/ort_genai.h | 14 +++--- src/ort_genai_c.cpp | 15 +++---- src/ort_genai_c.h | 32 
+++++++------ src/python/python.cpp | 11 +---- src/search.cpp | 5 --- src/search.h | 6 --- 18 files changed, 91 insertions(+), 169 deletions(-) diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index 68282ba5e5..30f9ffbdb3 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -20,16 +20,6 @@ public bool IsDone() return NativeMethods.OgaGenerator_IsDone(_generatorHandle) != 0; } - public bool HitEOS() - { - return NativeMethods.OgaGenerator_HitEOS(_generatorHandle) != 0; - } - - public bool HitMaxLength() - { - return NativeMethods.OgaGenerator_HitMaxLength(_generatorHandle) != 0; - } - public void SetModelInput(string name, Tensor value) { Result.VerifySuccess(NativeMethods.OgaGenerator_SetModelInput(_generatorHandle, StringUtils.ToUtf8(name), value.Handle)); @@ -56,6 +46,12 @@ public void AppendTokenSequences(Sequences sequences) Result.VerifySuccess(NativeMethods.OgaGenerator_AppendTokenSequences(_generatorHandle, sequences.Handle)); } + public int TokenCount() + { + Result.VerifySuccess(NativeMethods.OgaGenerator_TokenCount(_generatorHandle, out int count)); + return count; + } + public void GenerateNextToken() { Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle)); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index ca9c5fb261..2093289228 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -125,14 +125,6 @@ internal class NativeLib [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern byte OgaGenerator_IsDone(IntPtr /* const OgaGenerator* */ generator); - // This function is used to check if the generator has hit the EOS token id after generating all sequences. 
- [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern byte OgaGenerator_HitEOS(IntPtr /* const OgaGenerator* */ generator); - - // This function is used to check if the generator has hit the max length after generating all sequences. - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern byte OgaGenerator_HitMaxLength(IntPtr /* const OgaGenerator* */ generator); - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGenerator_GetNextTokens(IntPtr /* const OgaGenerator* */ generator, out IntPtr /* const int32_t** */ outTokenIds, @@ -162,6 +154,10 @@ internal class NativeLib public static extern IntPtr /* OgaResult* */ OgaGenerator_AppendTokenSequences(IntPtr /* OgaGenerator* */ generator, IntPtr /* const OgaSequences* */ sequences); + // This function is used to get the number of tokens in the generator. + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGenerator_TokenCount(IntPtr /* OgaGenerator* */ generator, + int* /* int32_t* */ count); // This function is used to rewind the generator to the given newLength. 
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index f4ebfd8820..36fa4a6ac9 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -26,8 +26,6 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) eos_seen_buffer_ = CudaMallocArray(batch_beam_size, &eos_seen_); done_cpu_ = CudaMallocHostArray(1); - hit_eos_cpu_ = CudaMallocHostArray(1); - hit_max_length_cpu_ = CudaMallocHostArray(1); eos_token_ids_ = params.p_device->Allocate(params.config.model.eos_token_id.size()); copy(std::span{params.config.model.eos_token_id}, eos_token_ids_.CpuSpan()); @@ -78,8 +76,6 @@ BeamSearch_Cuda::~BeamSearch_Cuda() = default; void Search_Cuda::ResetDone() { *done_cpu_ = false; - *hit_eos_cpu_ = false; - *hit_max_length_cpu_ = false; cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); } @@ -171,7 +167,7 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { // Check for EOS assert(next_tokens_.size() == eos_seen_.size()); - cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_seen_.data(), eos_token_ids_.Span().data(), static_cast(eos_token_ids_.Span().size()), params_->config.model.pad_token_id, done_cpu_.get(), hit_eos_cpu_.get(), GetStream()); + cuda::Launch_CheckForEOSAndPad(next_tokens_.data(), static_cast(next_tokens_.size()), eos_seen_.data(), eos_token_ids_.Span().data(), static_cast(eos_token_ids_.Span().size()), params_->config.model.pad_token_id, done_cpu_.get(), GetStream()); // Append tokens cudaStreamSynchronize(GetStream()); @@ -184,7 +180,6 @@ void GreedySearch_Cuda::SampleTopKTopP(int k, float p, float temperature) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "greedy cuda hit"); *done_cpu_ = true; - *hit_max_length_cpu_ = true; } } @@ -195,7 +190,6 @@ bool BeamSearch_Cuda::IsDone() const { if (sequences_.GetSequenceLength() == 
params_->search.max_length) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "beam cuda hit"); - *hit_max_length_cpu_ = true; return true; } return false; @@ -241,7 +235,6 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { if (GetLogItems().enabled && GetLogItems().hit_max_length) Log("hit_max_length", "greedy cuda hit"); *done_cpu_ = true; - *hit_max_length_cpu_ = true; return; } diff --git a/src/cuda/search_cuda.cu b/src/cuda/search_cuda.cu index 53c9690a65..e72689291b 100644 --- a/src/cuda/search_cuda.cu +++ b/src/cuda/search_cuda.cu @@ -74,7 +74,7 @@ struct ArgMaxDataImpl : ArgMaxData { cuda_unique_ptr> argmaxen_owner_; }; -__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu) { +__global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu) { for (int batch_id = 0; batch_id < next_tokens_count; ++batch_id) { // If EOS already met, pad if (eos_seen[batch_id]) { @@ -103,14 +103,13 @@ __global__ void CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, b } if (batch_id == next_tokens_count) { *done_cpu = true; - *hit_eos_cpu = true; return; } } } -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream) { - CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, next_tokens_count, eos_seen, eos_token_ids, eos_token_count, pad_token_id, done_cpu, hit_eos_cpu); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, cudaStream_t stream) { + CheckForEOSAndPad<<<1, 1, 0, stream>>>(next_tokens, 
next_tokens_count, eos_seen, eos_token_ids, eos_token_count, pad_token_id, done_cpu); } __global__ void AddProbsKernel(float* log_probs, diff --git a/src/cuda/search_cuda.cuh b/src/cuda/search_cuda.cuh index d039bacae3..3aa5aa0fc5 100644 --- a/src/cuda/search_cuda.cuh +++ b/src/cuda/search_cuda.cuh @@ -9,8 +9,8 @@ struct ArgMaxData { virtual ~ArgMaxData() = default; }; -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream); -void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, bool* hit_eos_cpu, cudaStream_t stream); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, bool* done_cpu, cudaStream_t stream); +void Launch_CheckForEOSAndPad(int32_t* next_tokens, int next_tokens_count, bool* eos_seen, const int32_t* eos_token_ids, int eos_token_count, int pad_token_id, bool* done_cpu, cudaStream_t stream); void Launch_ExpandInputSequences(const std::span input_sequences, std::span sequences, int batch_size, int beam_size, int max_length, cudaStream_t stream); void Launch_AppendNextTokensToSequences(std::span next_tokens, std::span sequences, int batch_beam_size, int past_length, int max_length, cudaStream_t stream); void Launch_GetLastTokens(int32_t* next_tokens, const int32_t* sequences, int batch_beam_size, int sequence_length, int max_length, cudaStream_t stream); diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index a2cf44b95f..6bcc398ff8 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -19,14 +19,6 @@ struct Search_Cuda : Search { cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event - bool HitEOS() const { - cudaStreamSynchronize(GetStream()); - return 
*hit_eos_cpu_; - } - bool HitMaxLength() const { - cudaStreamSynchronize(GetStream()); - return *hit_max_length_cpu_; - } void ResetDone() override; DeviceSpan GetLogits() const override; @@ -48,8 +40,6 @@ struct Search_Cuda : Search { DeviceSpan next_token_scores_; // shape (beam_size*batch_size, vocab_size) cuda_host_unique_ptr done_cpu_; - cuda_host_unique_ptr hit_eos_cpu_; - cuda_host_unique_ptr hit_max_length_cpu_; }; struct GreedySearch_Cuda : Search_Cuda { diff --git a/src/generators.cpp b/src/generators.cpp index b48261781a..4f4beee2c9 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -454,6 +454,10 @@ void Generator::SetRuntimeOption(const char* key, const char* value) { state_->SetRunOption(key, value); } +int32_t Generator::TokenCount() { + return static_cast(search_->GetSequenceLength()); +} + bool Generator::IsDone() { ThrowErrorIfSessionTerminated(state_->session_terminated_); if (computed_logits_) { @@ -472,14 +476,6 @@ bool Generator::IsDone() { return is_done; } -bool Generator::HitEOS() const { - return search_->HitEOS(); -} - -bool Generator::HitMaxLength() const { - return search_->HitMaxLength(); -} - bool Generator::IsSessionTerminated() const { return state_->session_terminated_; } diff --git a/src/generators.h b/src/generators.h index a6bacb2573..c0144b0430 100644 --- a/src/generators.h +++ b/src/generators.h @@ -95,8 +95,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); - bool HitEOS() const; - bool HitMaxLength() const; + int32_t TokenCount(); void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 0fd0471af8..1cd83f9464 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ 
b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -61,32 +61,6 @@ public boolean isDone() { return isDone(nativeHandle); } - /** - * Checks if the generation process ended because an EOS token id was hit. - * - * @return true if the EOS token was hit, false otherwise. - */ - public boolean hitEOS() { - if (nativeHandle == 0) { - throw new IllegalStateException("Instance has been freed and is invalid"); - } - - return hitEOS(nativeHandle); - } - - /** - * Checks if the generation process ended because the maximum length was hit. - * - * @return true if the maximum length was hit, false otherwise. - */ - public boolean hitMaxLength() { - if (nativeHandle == 0) { - throw new IllegalStateException("Instance has been freed and is invalid"); - } - - return hitMaxLength(nativeHandle); - } - /** * Add a Tensor as a model input. * @@ -156,6 +130,19 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException { appendTokenSequences(nativeHandle, sequences.nativeHandle()); } + /** + * Returns the token count in the generator. + * + * @throws GenAIException If the call to the GenAI native API fails. + */ + public int tokenCount() throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + tokenCount(nativeHandle); + } + /** * Rewinds the generator to the given length. This is useful when the user wants to rewind the * generator to a specific length and continue generating from that point. 
@@ -296,10 +283,6 @@ private native long createGenerator(long modelHandle, long generatorParamsHandle private native boolean isDone(long nativeHandle); - private native boolean hitEOS(long nativeHandle); - - private native boolean hitMaxLength(long nativeHandle); - private native void setModelInput(long nativeHandle, String inputName, long tensorHandle) throws GenAIException; @@ -310,6 +293,8 @@ private native void setModelInput(long nativeHandle, String inputName, long tens private native void appendTokenSequences(long nativeHandle, long sequencesHandle) throws GenAIException; + private native int tokenCount(long nativeHandle) throws GenAIException; + private native void rewindTo(long nativeHandle, long newLength) throws GenAIException; private native void generateNextTokenNative(long nativeHandle) throws GenAIException; diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp index 5fc72083e2..9ad21a0ac1 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp @@ -67,19 +67,18 @@ Java_ai_onnxruntime_genai_Generator_appendTokens(JNIEnv* env, jobject thiz, jlon env->ReleaseIntArrayElements(token_ids, tokens, JNI_ABORT); } -JNIEXPORT jboolean JNICALL -Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) { - return OgaGenerator_IsDone(reinterpret_cast(native_handle)); -} +JNIEXPORT jint JNICALL +Java_ai_onnxruntime_genai_Generator_tokenCount(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaGenerator* generator = reinterpret_cast(native_handle); + int32_t count = 0; -JNIEXPORT jboolean JNICALL -Java_ai_onnxruntime_genai_Generator_hitEOS(JNIEnv* env, jobject thiz, jlong native_handle) { - return OgaGenerator_HitEOS(reinterpret_cast(native_handle)); + ThrowIfError(env, OgaGenerator_TokenCount(generator, &count)); + return static_cast(count); } JNIEXPORT jboolean JNICALL 
-Java_ai_onnxruntime_genai_Generator_hitMaxLength(JNIEnv* env, jobject thiz, jlong native_handle) { - return OgaGenerator_HitMaxLength(reinterpret_cast(native_handle)); +Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) { + return OgaGenerator_IsDone(reinterpret_cast(native_handle)); } JNIEXPORT void JNICALL diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index efe8404255..e472d567c7 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -234,20 +234,6 @@ typedef NS_ENUM(NSInteger, OGAElementType) { */ - (BOOL)isDoneWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); -/** - * Whether generation ended because an EOS token id was hit. - * @param error Optional error information set if an error occurs. - * @return The result, or false if an error occurs. - */ -- (BOOL)hitEOSWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); - -/** - * Whether generation ended because the maximum length was hit. - * @param error Optional error information set if an error occurs. - * @return The result, or false if an error occurs. - */ -- (BOOL)hitMaxLengthWithError:(NSError**)error __attribute__((swift_error(nonnull_error))); - /** * Set input with NamedTensors type. * @param namedTensors The named tensors. @@ -280,6 +266,13 @@ typedef NS_ENUM(NSInteger, OGAElementType) { */ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error; +/** + * Get the number of tokens in the generator. + * @param error Optional error information set if an error occurs. + * @return The number of tokens in the generator or int32_t(-1) if an error occurs. + */ +- (int32_t)tokenCount:(NSError**)error; + /** * Rewinds the generator to the given length. * @param newLength The desired length in tokens after rewinding. 
diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm index b4874594dc..9c6514fc6b 100644 --- a/src/objectivec/oga_generator.mm +++ b/src/objectivec/oga_generator.mm @@ -30,20 +30,6 @@ - (BOOL)isDoneWithError:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } -- (BOOL)hitEOSWithError:(NSError**)error { - try { - return _generator->HitEOS(); - } - OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) -} - -- (BOOL)hitMaxLengthWithError:(NSError**)error { - try { - return _generator->HitMaxLength(); - } - OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) -} - - (BOOL)setInputs:(OGANamedTensors*)namedTensors error:(NSError**)error { try { _generator->SetInputs([namedTensors CXXAPIOgaNamedTensors]); @@ -81,6 +67,13 @@ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } +- (int32_t)tokenCount:(NSError**)error { + try { + return _generator->TokenCount(); + } + OGA_OBJC_API_IMPL_CATCH(error, int32_t(-1)) +} + - (BOOL)rewindTo:(size_t)newLength error:(NSError**)error { try { _generator->RewindTo(newLength); diff --git a/src/ort_genai.h b/src/ort_genai.h index 7c60e94a31..679e700b78 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -448,12 +448,8 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } - bool HitEOS() { - return OgaGenerator_HitEOS(this); - } - - bool HitMaxLength() { - return OgaGenerator_HitMaxLength(this); + bool IsSessionTerminated() const { + return OgaGenerator_IsSessionTerminated(this); } void SetModelInput(const char* name, OgaTensor& tensor) { @@ -478,8 +474,10 @@ struct OgaGenerator : OgaAbstract { } #endif - bool IsSessionTerminated() const { - return OgaGenerator_IsSessionTerminated(this); + int32_t TokenCount() const { + int32_t count; + OgaCheckResult(OgaGenerator_TokenCount(this, &count)); + return count; } void GenerateNextToken() { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index bebfb0c1db..f5cb1f41e3 100644 --- 
a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -413,14 +413,6 @@ bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator) { return generator->IsDone(); } -bool OGA_API_CALL OgaGenerator_HitEOS(OgaGenerator* generator) { - return generator->HitEOS(); -} - -bool OGA_API_CALL OgaGenerator_HitMaxLength(OgaGenerator* generator) { - return generator->HitMaxLength(); -} - bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator) { return generator->IsSessionTerminated(); } @@ -465,6 +457,13 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_TokenCount(OgaGenerator* generator, int32_t* count) { + OGA_TRY + *count = generator->TokenCount(); + return nullptr; + OGA_CATCH +} + OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) { OGA_TRY generator->GenerateNextToken(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index e08a3891c2..9063ea83da 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -448,19 +448,10 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator); /** - * \brief Returns true if the generator is done because it hit the EOS token id after generating all the sequences. - * \param[in] generator The generator to check if it is done with generating all sequences. - * \return True if the generator has hit the EOS token id, false otherwise. - */ -OGA_EXPORT bool OGA_API_CALL OgaGenerator_HitEOS(OgaGenerator* generator); - -/** - * \brief Returns true if the generator is done because it hit the maximum length after generating all the sequences. - * \param[in] generator The generator to check if it is done with generating all sequences. - * \return True if the generator has hit the maximum length, false otherwise. + * \brief Returns true if the session has been terminated. 
+ * \param[in] generator The generator to add the inputs to. + * \return True if the session has been terminated, false otherwise. */ -OGA_EXPORT bool OGA_API_CALL OgaGenerator_HitMaxLength(OgaGenerator* generator); - OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator); /** @@ -496,6 +487,14 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerato */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const int32_t* input_ids, size_t input_ids_count); +/** + * \brief Returns the number of tokens in the generator + * \param[in] generator The generator containing the appended tokens. + * \param[out] count The number of tokens that have been added. + * \return OgaResult containing the error message if the getting of the number of tokens failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_TokenCount(OgaGenerator* generator, int32_t* count); + /** * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. * \param[in] generator The generator to compute the logits for. @@ -512,6 +511,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetNextTokens(const OgaGenerator* generator, const int32_t** out, size_t* out_count); +/** + * \brief Set a runtime option's name and value. + * \param[in] generator The generator to rewind to the given length. + * \param[in] key The runtime option's name + * \param[in] value The runtime option's value + * \return OgaResult containing the error message if setting the runtime option failed. 
+ */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value); /** @@ -636,7 +642,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetBosTokenId(const OgaTokenizer* OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetEosTokenIds(const OgaTokenizer* tokenizer, const int32_t** eos_token_ids, size_t* token_count); /** - * Return the int representation of the BOS token + * Return the int representation of the PAD token */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetPadTokenId(const OgaTokenizer* tokenizer, int32_t* token_id); diff --git a/src/python/python.cpp b/src/python/python.cpp index 3ce581bd1d..db9591c8da 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -260,14 +260,6 @@ struct PyGenerator { return generator_->IsDone(); } - bool HitEOS() { - return generator_->HitEOS(); - } - - bool HitMaxLength() { - return generator_->HitMaxLength(); - } - void SetActiveAdapter(OgaAdapters& adapters, const std::string& adapter_name) { generator_->SetActiveAdapter(adapters, adapter_name.c_str()); } @@ -463,14 +455,13 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "Generator") .def(pybind11::init()) .def("is_done", &PyGenerator::IsDone) - .def("hit_eos", &PyGenerator::HitEOS) - .def("hit_max_length", &PyGenerator::HitMaxLength) .def("get_input", &PyGenerator::GetInput) .def("get_output", &PyGenerator::GetOutput) .def("set_inputs", &PyGenerator::SetInputs) .def("set_model_input", &PyGenerator::SetModelInput) .def("append_tokens", pybind11::overload_cast&>(&PyGenerator::AppendTokens)) .def("append_tokens", pybind11::overload_cast(&PyGenerator::AppendTokens)) + .def("token_count", &OgaGenerator::TokenCount) .def("get_logits", &PyGenerator::GetLogits) .def("set_logits", &PyGenerator::SetLogits) .def("generate_next_token", &PyGenerator::GenerateNextToken) diff --git a/src/search.cpp b/src/search.cpp index 24d5d97b47..3e9d545c56 100644 --- a/src/search.cpp +++ b/src/search.cpp 
@@ -52,8 +52,6 @@ BeamSearch_Cpu::~BeamSearch_Cpu() = default; void Search_Cpu::ResetDone() { // Reset done count/state done_ = false; - hit_eos_ = false; - hit_max_length_ = false; } void GreedySearch_Cpu::ResetDone() { @@ -311,7 +309,6 @@ void GreedySearch_Cpu::SetNextToken(size_t batch_id, int32_t token) { Log("hit_eos", "EOS seen on batch " + std::to_string(batch_id)); if (--not_done_count_ == 0) { done_ = true; - hit_eos_ = true; } } } @@ -333,7 +330,6 @@ void GreedySearch_Cpu::AppendNextTokensToSequences() { if (g_log.enabled && g_log.hit_max_length) Log("hit_max_length", "greedy cpu hit"); done_ = true; - hit_max_length_ = true; } } @@ -419,7 +415,6 @@ void BeamSearch_Cpu::AppendNextTokensToSequences() { if (g_log.enabled && g_log.hit_max_length) Log("hit_max_length", "beam cpu hit"); done_ = true; - hit_max_length_ = true; } } diff --git a/src/search.h b/src/search.h index f81ac93e85..98c0e91a73 100644 --- a/src/search.h +++ b/src/search.h @@ -19,8 +19,6 @@ struct Search : LeakChecked { virtual DeviceSpan GetLogits() const = 0; virtual void SetLogits(DeviceSpan logits) = 0; virtual bool IsDone() const = 0; - virtual bool HitEOS() const = 0; - virtual bool HitMaxLength() const = 0; virtual void ResetDone() = 0; virtual void SelectTop() = 0; @@ -47,8 +45,6 @@ struct Search_Cpu : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const override { return done_; } - bool HitEOS() const override { return hit_eos_; } - bool HitMaxLength() const override { return hit_max_length_; } void ResetDone() override; DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; @@ -67,8 +63,6 @@ struct Search_Cpu : Search { DeviceSpan next_token_scores_; // shape (beam_size*batch_size, vocab_size) bool done_{}; - bool hit_eos_{}; - bool hit_max_length_{}; }; struct GreedySearch_Cpu : Search_Cpu { From 06f67036df24e369d65475634c2245fe5aff3ea4 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 
01:24:21 +0000 Subject: [PATCH 10/38] Add GetSearchNumber and GetSearchBool APIs --- src/csharp/GeneratorParams.cs | 15 ++++-- src/csharp/NativeMethods.cs | 23 ++++++--- src/generators.cpp | 46 +++++++++++++++++ src/generators.h | 4 ++ .../ai/onnxruntime/genai/GeneratorParams.java | 36 ++++++++++++- .../ai_onnxruntime_genai_GeneratorParams.cpp | 25 ++++++++-- src/objectivec/error_utils.h | 3 +- src/objectivec/include/ort_genai_objc.h | 38 ++++++++++++++ src/objectivec/oga_generator_params.mm | 14 ++++++ src/objectivec/oga_tokenizer.mm | 29 +++++++++++ src/ort_genai.h | 16 ++++-- src/ort_genai_c.cpp | 15 ++++-- src/ort_genai_c.h | 50 +++++++++++++++++-- src/python/python.cpp | 26 ++++++++-- 14 files changed, 307 insertions(+), 33 deletions(-) diff --git a/src/csharp/GeneratorParams.cs b/src/csharp/GeneratorParams.cs index cc23eba0f9..4eb7e01fa7 100644 --- a/src/csharp/GeneratorParams.cs +++ b/src/csharp/GeneratorParams.cs @@ -29,14 +29,21 @@ public void SetSearchOption(string searchOption, bool value) Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value)); } - public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize) + public void SetGuidance(string type, string data, bool enableFFTokens = false) { - Console.WriteLine("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release."); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens)); } - public void SetGuidance(string type, string data, bool enableFFTokens = false) + public double GetSearchNumber(string searchOption) { - Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens)); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), 
out double value)); + return value; + } + + public bool GetSearchBool(string searchOption) + { + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out bool value)); + return value; } ~GeneratorParams() diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index 2093289228..212d3e5f58 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -106,18 +106,27 @@ internal class NativeLib public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams, byte[] /* const char* */ searchOption, bool value); - - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model, - IntPtr /* const OgaGeneratorParams* */ generatorParams, - out IntPtr /* OgaGenerator** */ generator); - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetGuidance(IntPtr /* OgaGeneratorParams* */ generatorParams, byte[] /* const char* */ type, byte[] /* const char* */ data, bool /* boolean */ enable_ff_tokens); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchNumber(IntPtr /* OgaGeneratorParams* */ generatorParams, + byte[] /* const char* */ searchOption, + out double /* const double* */ value); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams, + byte[] /* const char* */ searchOption, + out bool /* const bool* */ value); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ 
OgaCreateGenerator(IntPtr /* const OgaModel* */ model, + IntPtr /* const OgaGeneratorParams* */ generatorParams, + out IntPtr /* OgaGenerator** */ generator); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern void OgaDestroyGenerator(IntPtr /* OgaGenerator* */ generator); @@ -157,7 +166,7 @@ internal class NativeLib // This function is used to get the number of tokens in the generator. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGenerator_TokenCount(IntPtr /* OgaGenerator* */ generator, - int* /* int32_t* */ count); + out int /* int32_t* */ count); // This function is used to rewind the generator to the given newLength. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] diff --git a/src/generators.cpp b/src/generators.cpp index 4f4beee2c9..9abedd4574 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -288,6 +288,52 @@ bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_t (search.num_beams == 1 || model_type == "whisper"); } +double GeneratorParams::GetSearchNumber(std::string_view name) const { + if (name == "batch_size") { + return static_cast(search.batch_size); + } else if (name == "chunk_size" && search.chunk_size.has_value()) { + return static_cast(search.chunk_size); + } else if (name == "diversity_penalty") { + return search.diversity_penalty; + } else if (name == "length_penalty") { + return search.length_penalty; + } else if (name == "max_length") { + return static_cast(search.max_length); + } else if (name == "min_length") { + return static_cast(search.min_length); + } else if (name == "no_repeat_ngram_size") { + return static_cast(search.no_repeat_ngram_size); + } else if (name == "num_beams") { + return static_cast(search.num_beams); + } else if (name == "num_return_sequences") { + return static_cast(search.num_return_sequences); + } else if (name == 
"random_seed") { + return static_cast(search.random_seed); + } else if (name == "repetition_penalty") { + return search.repetition_penalty; + } else if (name == "temperature") { + return search.temperature; + } else if (name == "top_k") { + return static_cast(search.top_k); + } else if (name == "top_p") { + return search.top_p; + } else { + throw std::runtime_error("Invalid name for GetSearchNumber."); + } +} + +bool GeneratorParams::GetSearchBool(std::string_view name) const { + if (name == "do_sample") { + return search.do_sample; + } else if (name == "early_stopping") { + return search.early_stopping; + } else if (name == "past_present_share_buffer") { + return search.past_present_share_buffer; + } else { + throw std::runtime_error("Invalid name for GetSearchBool."); + } +} + std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params) { return std::make_unique(model, params); } diff --git a/src/generators.h b/src/generators.h index c0144b0430..89806d8aec 100644 --- a/src/generators.h +++ b/src/generators.h @@ -74,6 +74,10 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec const Config& config; // The model outlives the GeneratorParams Config::Search search{config.search}; // Copy of the search parameters from the config + // Query the params to get the value set for a param + double GetSearchNumber(std::string_view name) const; + bool GetSearchBool(std::string_view name) const; + int max_batch_size{0}; bool use_graph_capture{}; bool use_multi_profile{}; diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java index 7bf8306f4a..ce043234c8 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java @@ -29,7 +29,7 @@ public GeneratorParams(Model model) throws GenAIException { } /** - * Set seach option with double value. 
+ * Set search option with double value. * * @param optionName The option name. * @param value The option value. @@ -58,6 +58,36 @@ public void setSearchOption(String optionName, boolean value) throws GenAIExcept setSearchOptionBool(nativeHandle, optionName, value); } + /** + * Get search option with numerical value. + * + * @param optionName The option name. + * @return The option value. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public void getSearchNumber(String optionName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + getSearchNumber(nativeHandle, optionName); + } + + /** + * Get search option with boolean value. + * + * @param optionName The option name. + * @return The option value. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public void getSearchBool(String optionName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + getSearchBool(nativeHandle, optionName); + } + @Override public void close() { if (nativeHandle != 0) { @@ -87,4 +117,8 @@ private native void setSearchOptionNumber(long nativeHandle, String optionName, private native void setSearchOptionBool(long nativeHandle, String optionName, boolean value) throws GenAIException; + + private native void getSearchNumber(long nativeHandle, String optionName) throws GenAIException; + + private native void getSearchBool(long nativeHandle, String optionName) throws GenAIException; } diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index 3664c0e3e8..df29a3e7a1 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -35,11 +35,30 @@ 
Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionNumber(JNIEnv* env, job ThrowIfError(env, OgaGeneratorParamsSetSearchNumber(generator_params, name, value)); } -JNIEXPORT void JNICALL -Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle, - jstring option_name, jboolean value) { +JNIEXPORT jdouble JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); CString name{env, option_name}; ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, value)); } + +JNIEXPORT jdouble JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_getSearchNumber(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { + OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); + CString name{env, option_name}; + double value = 0.0; + + ThrowIfError(env, OgaGeneratorParamsGetSearchNumber(generator_params, name, &value)); + return static_cast(value); +} + +JNIEXPORT jboolean JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_getSearchBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { + OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); + CString name{env, option_name}; + bool value = false; + + ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, &value)); + return static_cast(value); +} diff --git a/src/objectivec/error_utils.h b/src/objectivec/error_utils.h index 8c73faa971..9359a23cf5 100644 --- a/src/objectivec/error_utils.h +++ b/src/objectivec/error_utils.h @@ -23,7 +23,8 @@ void OGASaveExceptionToError(const std::exception& e, NSError** error); } #define OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) OGA_OBJC_API_IMPL_CATCH(error, NO) - +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE (error) OGA_OBJC_API_IMPL_CATCH(error, 0.0) +#define 
OGA_OBJC_API_IMPL_CATCH_RETURNING_INT (error) OGA_OBJC_API_IMPL_CATCH(error, 0) #define OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) OGA_OBJC_API_IMPL_CATCH(error, nil) NS_ASSUME_NONNULL_END diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index e472d567c7..7d615d5172 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -81,6 +81,27 @@ typedef NS_ENUM(NSInteger, OGAElementType) { - (nullable instancetype)initWithModel:(OGAModel*)model error:(NSError**)error NS_DESIGNATED_INITIALIZER; +/** + * Return the int representation of the BOS token. + * + * @return The BOS token id + */ +- (int32_t)getBosTokenId:(NSError**)error; + +/** + * Return the int representations of the array of EOS tokens. + * + * @return The array of EOS token ids + */ +- (nullable NSArray*)getEosTokenIds:(NSError**)error; + +/** + * Return the int representation of the PAD token. + * + * @return The PAD token id + */ +- (int32_t)getPadTokenId:(NSError**)error; + /** * Encode text to sequences * @@ -208,6 +229,23 @@ typedef NS_ENUM(NSInteger, OGAElementType) { - (BOOL)setSearchOption:(NSString*)key boolValue:(BOOL)value error:(NSError**)error; + +/** + * Get numerical value of option. + * @param key The option key. + * @param error Optional error information set if an error occurs. + * @return The option value. + */ +- (double)getSearchNumber:(NSString*)key + error:(NSError**)error; +/** + * Get boolean value of option. + * @param key The option key. + * @param error Optional error information set if an error occurs. + * @return The option value. 
+ */ +- (BOOL)getSearchBool:(NSString*)key + error:(NSError**)error; @end /** diff --git a/src/objectivec/oga_generator_params.mm b/src/objectivec/oga_generator_params.mm index 1ae8719d41..25ecfd0386 100644 --- a/src/objectivec/oga_generator_params.mm +++ b/src/objectivec/oga_generator_params.mm @@ -37,6 +37,20 @@ - (BOOL)setSearchOption:(NSString*)key boolValue:(BOOL)value error:(NSError**)er OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } +- (double)getSearchNumber:(NSString*)key error:(NSError**)error { + try { + return _generatorParams->GetSearchNumber([key UTF8String]); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE(error) +} + +- (BOOL)getSearchBool:(NSString*)key error:(NSError**)error { + try { + return _generatorParams->GetSearchBool([key UTF8String]); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + - (OgaGeneratorParams&)CXXAPIOgaGeneratorParams { return *(_generatorParams.get()); } diff --git a/src/objectivec/oga_tokenizer.mm b/src/objectivec/oga_tokenizer.mm index c961faabc8..10f67f7b13 100644 --- a/src/objectivec/oga_tokenizer.mm +++ b/src/objectivec/oga_tokenizer.mm @@ -21,6 +21,35 @@ - (nullable instancetype)initWithModel:(OGAModel*)model error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) } +- (int32_t)getBosTokenId:(NSError**)error { + try { + return _tokenizer->GetBosTokenId(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) +} + +- (nullable NSArray*)getEosTokenIds:(NSError**)error { + try { + std::vector eos_ids = _tokenizer->GetEosTokenIds(); + NSMutableArray* result = [NSMutableArray arrayWithCapacity:eos_ids.size()]; + if (!result) { + return nil; + } + for (int32_t eos_id : eos_ids) { + [result addObject:@(eos_id)]; + } + return result; + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (int32_t)getPadTokenId:(NSError**)error { + try { + return _tokenizer->GetPadTokenId(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) +} + - (nullable OGASequences*)encode:(NSString*)str 
error:(NSError**)error { OGASequences* sequences = [[OGASequences alloc] initWithError:error]; if (!sequences) { diff --git a/src/ort_genai.h b/src/ort_genai.h index 679e700b78..73f8d64344 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -426,14 +426,22 @@ struct OgaGeneratorParams : OgaAbstract { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); } - void TryGraphCaptureWithMaxBatchSize(int /*max_batch_size*/) { - printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n"); - } - void SetGuidance(const char* type, const char* data, bool enable_ff_tokens = false) { OgaCheckResult(OgaGeneratorParamsSetGuidance(this, type, data, enable_ff_tokens)); } + double GetSearchNumber(const char* name) const { + double value; + OgaCheckResult(OgaGeneratorParamsGetSearchNumber(this, name, &value)); + return value; + } + + bool GetSearchBool(const char* name) const { + bool value; + OgaCheckResult(OgaGeneratorParamsGetSearchBool(this, name, &value)); + return value; + } + static void operator delete(void* p) { OgaDestroyGeneratorParams(reinterpret_cast(p)); } }; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index f5cb1f41e3..c2b652374a 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -388,16 +388,23 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* para OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size) { +OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) { OGA_TRY - printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n"); + params->SetGuidance(type, data, enable_ff_tokens); return nullptr; OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) { 
+OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(OgaGeneratorParams* params, const char* name, double* value) { OGA_TRY - params->SetGuidance(type, data, enable_ff_tokens); + *value = params->GetIntValue(name); + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(OgaGeneratorParams* params, const char* name, bool* value) { + OGA_TRY + *value = params->GetBoolValue(name); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 9063ea83da..9852f2c110 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -411,9 +411,23 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* mode */ OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* params); +/** + * \brief Set a numerical value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[in] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* params, const char* name, double value); + +/** + * \brief Set a boolean value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[in] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. 
+ */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* params, const char* name, bool value); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size); /** * \brief Sets the guidance type and data for the Generator params @@ -425,6 +439,24 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatch */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens); +/** + * \brief Get a numerical value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[out] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(OgaGeneratorParams* params, const char* name, double* value); + +/** + * \brief Get a boolean value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[out] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(OgaGeneratorParams* params, const char* name, bool* value); + /** * \brief Creates a generator from the given model and generator params. * \param[in] model The model to use for generation. 
@@ -632,17 +664,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaUpdateTokenizerOptions( size_t num_options); /** - * Return the int representation of the BOS token + * \brief Return the int representation of the BOS token + * \param[in] tokenizer The tokenizer to read from + * \param[out] token_id The BOS token id + * \return OgaResult containing the error message if returning the BOS token id fails. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetBosTokenId(const OgaTokenizer* tokenizer, int32_t* token_id); /** - * Return an array containing the int representations of the EOS tokens. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed. + * \brief Return an array containing the int representations of the EOS token ids. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed. + * \param[in] tokenizer The tokenizer to read from + * \param[out] eos_token_ids The array of EOS token ids + * \param[out] token_count The length of the array + * \return OgaResult containing the error message if returning the EOS token ids fails. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetEosTokenIds(const OgaTokenizer* tokenizer, const int32_t** eos_token_ids, size_t* token_count); /** - * Return the int representation of the PAD token + * \brief Return the int representation of the PAD token + * \param[in] tokenizer The tokenizer to read from + * \param[out] token_id The PAD token id + * \return OgaResult containing the error message if returning the PAD token id fails. 
*/ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetPadTokenId(const OgaTokenizer* tokenizer, int32_t* token_id); diff --git a/src/python/python.cpp b/src/python/python.cpp index db9591c8da..aa802fb7dc 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -192,10 +192,6 @@ struct PyGeneratorParams { } } - void TryGraphCaptureWithMaxBatchSize(pybind11::int_ max_batch_size) { - std::cerr << "TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release" << std::endl; - } - void SetGuidance(const std::string& type, const std::string& data, bool enable_ff_tokens = false) { params_->SetGuidance(type.c_str(), data.c_str(), enable_ff_tokens); } @@ -316,11 +312,31 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "GeneratorParams") .def(pybind11::init()) - .def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize) .def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options .def("set_guidance", &PyGeneratorParams::SetGuidance, pybind11::arg("type"), pybind11::arg("data"), pybind11::arg("enable_ff_tokens") = false); + .def("get_search_options", [](const GeneratorParams& p) { + py::dict d; + d["batch_size"] = p.batch_size; + d["chunk_size"] = p.chunk_size; + d["diversity_penalty"] = p.diversity_penalty; + d["do_sample"] = p.do_sample; + d["early_stopping"] = p.early_stopping; + d["length_penalty"] = p.length_penalty; + d["max_length"] = p.max_length; + d["min_length"] = p.min_length; + d["no_repeat_ngram_size"] = p.no_repeat_ngram_size; + d["num_beams"] = p.num_beams; + d["num_return_sequences"] = p.num_return_sequences; + d["past_present_share_buffer"] = p.past_present_share_buffer; + d["random_seed"] = p.random_seed; + d["repetition_penalty"] = p.repetition_penalty; + d["temperature"] = p.temperature; + d["top_k"] = p.top_k; + d["top_p"] = p.top_p; + return d; + }) pybind11::class_(m, "TokenizerStream") .def("decode", 
[](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); }); From d18fe38b6f89eff10ef4d74f953d5c72881313d3 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:26:34 +0000 Subject: [PATCH 11/38] Undo accidental change --- .../src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index df29a3e7a1..139003ff43 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -35,8 +35,9 @@ Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionNumber(JNIEnv* env, job ThrowIfError(env, OgaGeneratorParamsSetSearchNumber(generator_params, name, value)); } -JNIEXPORT jdouble JNICALL -Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { +JNIEXPORT void JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobject thiz, jlong native_handle, + jstring option_name, jboolean value) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); CString name{env, option_name}; From 93b627928ec62085b9d96f5a1e529405e95acf33 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:37:18 +0000 Subject: [PATCH 12/38] Fix return types in Java bindings --- .../java/ai/onnxruntime/genai/GeneratorParams.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java index ce043234c8..c6fd3f4945 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java @@ -65,12 +65,12 @@ public void 
setSearchOption(String optionName, boolean value) throws GenAIExcept * @return The option value. * @throws GenAIException If the call to the GenAI native API fails. */ - public void getSearchNumber(String optionName) throws GenAIException { + public double getSearchNumber(String optionName) throws GenAIException { if (nativeHandle == 0) { throw new IllegalStateException("Instance has been freed and is invalid"); } - getSearchNumber(nativeHandle, optionName); + return getSearchNumber(nativeHandle, optionName); } /** @@ -80,12 +80,12 @@ public void getSearchNumber(String optionName) throws GenAIException { * @return The option value. * @throws GenAIException If the call to the GenAI native API fails. */ - public void getSearchBool(String optionName) throws GenAIException { + public boolean getSearchBool(String optionName) throws GenAIException { if (nativeHandle == 0) { throw new IllegalStateException("Instance has been freed and is invalid"); } - getSearchBool(nativeHandle, optionName); + return getSearchBool(nativeHandle, optionName); } @Override @@ -118,7 +118,7 @@ private native void setSearchOptionNumber(long nativeHandle, String optionName, private native void setSearchOptionBool(long nativeHandle, String optionName, boolean value) throws GenAIException; - private native void getSearchNumber(long nativeHandle, String optionName) throws GenAIException; + private native double getSearchNumber(long nativeHandle, String optionName) throws GenAIException; - private native void getSearchBool(long nativeHandle, String optionName) throws GenAIException; + private native boolean getSearchBool(long nativeHandle, String optionName) throws GenAIException; } From 87ef255abac503a792b4354932695133d7b30b8f Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:39:16 +0000 Subject: [PATCH 13/38] Add missing value call --- src/generators.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index 
9abedd4574..9cb3ef127b 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -292,7 +292,7 @@ double GeneratorParams::GetSearchNumber(std::string_view name) const { if (name == "batch_size") { return static_cast(search.batch_size); } else if (name == "chunk_size" && search.chunk_size.has_value()) { - return static_cast(search.chunk_size); + return static_cast(search.chunk_size.value()); } else if (name == "diversity_penalty") { return search.diversity_penalty; } else if (name == "length_penalty") { From d5c51a179b4bc7b97714842585f9eb77c69719ee Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:44:29 +0000 Subject: [PATCH 14/38] Update return type for Objective-C binding of TokenCount --- src/objectivec/include/ort_genai_objc.h | 4 ++-- src/objectivec/oga_generator.mm | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index 7d615d5172..c3faf4d6c9 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -307,9 +307,9 @@ typedef NS_ENUM(NSInteger, OGAElementType) { /** * Get the number of tokens in the generator. * @param error Optional error information set if an error occurs. - * @return The number of tokens in the generator or int32_t(-1) if an error occurs. + * @return The number of tokens in the generator */ -- (int32_t)tokenCount:(NSError**)error; +- (int)tokenCount:(NSError**)error; /** * Rewinds the generator to the given length. 
diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm index 9c6514fc6b..5b2c51adf4 100644 --- a/src/objectivec/oga_generator.mm +++ b/src/objectivec/oga_generator.mm @@ -67,11 +67,11 @@ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } -- (int32_t)tokenCount:(NSError**)error { +- (int)tokenCount:(NSError**)error { try { return _generator->TokenCount(); } - OGA_OBJC_API_IMPL_CATCH(error, int32_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) } - (BOOL)rewindTo:(size_t)newLength error:(NSError**)error { From deac1b16ada63c86112991860c7abc43d2f7ca47 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:46:32 +0000 Subject: [PATCH 15/38] Add missing return in Java binding of TokenCount --- src/java/src/main/java/ai/onnxruntime/genai/Generator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 1cd83f9464..301bff30bd 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -140,7 +140,7 @@ public int tokenCount() throws GenAIException { throw new IllegalStateException("Instance has been freed and is invalid"); } - tokenCount(nativeHandle); + return tokenCount(nativeHandle); } /** From 606b494655ee4eb69b46304a8258afca006ccf30 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 01:52:15 +0000 Subject: [PATCH 16/38] Fix names of APIs called in C API --- src/ort_genai_c.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index c2b652374a..149900517b 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -397,14 +397,14 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params OgaResult* OGA_API_CALL 
OgaGeneratorParamsGetSearchNumber(OgaGeneratorParams* params, const char* name, double* value) { OGA_TRY - *value = params->GetIntValue(name); + *value = params->GetSearchNumber(name); return nullptr; OGA_CATCH } OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(OgaGeneratorParams* params, const char* name, bool* value) { OGA_TRY - *value = params->GetBoolValue(name); + *value = params->GetSearchBool(name); return nullptr; OGA_CATCH } From 279eaa2f641e82ba857c3cd7b53ac7e801e6370f Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 02:11:57 +0000 Subject: [PATCH 17/38] Add missing const references --- src/generators.cpp | 2 +- src/ort_genai_c.cpp | 6 +++--- src/python/python.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 9cb3ef127b..b8e0c4b055 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -500,7 +500,7 @@ void Generator::SetRuntimeOption(const char* key, const char* value) { state_->SetRunOption(key, value); } -int32_t Generator::TokenCount() { +int32_t Generator::TokenCount() const { return static_cast(search_->GetSequenceLength()); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 149900517b..86738d917a 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -395,14 +395,14 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(OgaGeneratorParams* params, const char* name, double* value) { +OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value) { OGA_TRY *value = params->GetSearchNumber(name); return nullptr; OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(OgaGeneratorParams* params, const char* name, bool* value) { +OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value) { OGA_TRY 
*value = params->GetSearchBool(name); return nullptr; @@ -464,7 +464,7 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerator_TokenCount(OgaGenerator* generator, int32_t* count) { +OgaResult* OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator, int32_t* count) { OGA_TRY *count = generator->TokenCount(); return nullptr; diff --git a/src/python/python.cpp b/src/python/python.cpp index aa802fb7dc..129feaabe0 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -315,7 +315,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options .def("set_guidance", &PyGeneratorParams::SetGuidance, pybind11::arg("type"), pybind11::arg("data"), - pybind11::arg("enable_ff_tokens") = false); + pybind11::arg("enable_ff_tokens") = false) .def("get_search_options", [](const GeneratorParams& p) { py::dict d; d["batch_size"] = p.batch_size; @@ -336,7 +336,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { d["top_k"] = p.top_k; d["top_p"] = p.top_p; return d; - }) + }); pybind11::class_(m, "TokenizerStream") .def("decode", [](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); }); From 67487c2d2b3924afa76823c7acba914bd7b92e0e Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 02:15:47 +0000 Subject: [PATCH 18/38] Add changes suggested by C++ linter --- src/python/python.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index 129feaabe0..bbe5627558 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -409,7 +409,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) { return ToPython(sequences->Get(0)); }) .def("to_token_id", &OgaTokenizer::ToTokenId) .def("decode", [](const OgaTokenizer& t, pybind11::array_t tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; }) - 
.def("apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) + .def( + "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) .def("encode_batch", [](const OgaTokenizer& t, std::vector strings) { std::vector c_strings; for (const auto& s : strings) From 09fc909ecfbb5def74a51d3a4479484d97db586f Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 02:25:08 +0000 Subject: [PATCH 19/38] Add some more missing const references --- src/generators.h | 2 +- src/ort_genai_c.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/generators.h b/src/generators.h index 89806d8aec..1167b9d804 100644 --- a/src/generators.h +++ b/src/generators.h @@ -99,7 +99,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); - int32_t TokenCount(); + int32_t TokenCount() const; void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 9852f2c110..1b1fbac974 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -446,7 +446,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorPar * \param[out] value The value of the 
search parameter. * \return OgaResult containing the error message if setting the generator params failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(OgaGeneratorParams* params, const char* name, double* value); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value); /** * \brief Get a boolean value for a search parameter @@ -455,7 +455,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(OgaGenerato * \param[out] value The value of the search parameter. * \return OgaResult containing the error message if setting the generator params failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(OgaGeneratorParams* params, const char* name, bool* value); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value); /** * \brief Creates a generator from the given model and generator params. @@ -525,7 +525,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* gener * \param[out] count The number of tokens that have been added. * \return OgaResult containing the error message if the getting of the number of tokens failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_TokenCount(OgaGenerator* generator, int32_t* count); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator, int32_t* count); /** * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. 
From 8d9fda815d8aaeb40107d0aba85c3925fb552b23 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 02:43:49 +0000 Subject: [PATCH 20/38] Change how Python binding is done --- src/python/python.cpp | 49 ++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index bbe5627558..febd2c9421 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -196,6 +196,28 @@ struct PyGeneratorParams { params_->SetGuidance(type.c_str(), data.c_str(), enable_ff_tokens); } + py::dict GetSearchOptions() { + py::dict d; + d["batch_size"] = params_->GetSearchNumber("batch_size"); + d["chunk_size"] = params_->GetSearchNumber("chunk_size"); + d["diversity_penalty"] = params_->GetSearchNumber("diversity_penalty"); + d["do_sample"] = params_->GetSearchBool("do_sample"); + d["early_stopping"] = params_->GetSearchBool("early_stopping"); + d["length_penalty"] = params_->GetSearchNumber("length_penalty"); + d["max_length"] = params_->GetSearchNumber("max_length"); + d["min_length"] = params_->GetSearchNumber("min_length"); + d["no_repeat_ngram_size"] = params_->GetSearchNumber("no_repeat_ngram_size"); + d["num_beams"] = params_->GetSearchNumber("num_beams"); + d["num_return_sequences"] = params_->GetSearchNumber("num_return_sequences"); + d["past_present_share_buffer"] = params_->GetSearchBool("past_present_share_buffer"); + d["random_seed"] = params_->GetSearchNumber("random_seed"); + d["repetition_penalty"] = params_->GetSearchNumber("repetition_penalty"); + d["temperature"] = params_->GetSearchNumber("temperature"); + d["top_k"] = params_->GetSearchNumber("top_k"); + d["top_p"] = params_->GetSearchNumber("top_p"); + return d; + } + std::vector refs_; // References to data we want to ensure doesn't get garbage collected }; @@ -316,27 +338,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("set_guidance", &PyGeneratorParams::SetGuidance, pybind11::arg("type"), 
pybind11::arg("data"), pybind11::arg("enable_ff_tokens") = false) - .def("get_search_options", [](const GeneratorParams& p) { - py::dict d; - d["batch_size"] = p.batch_size; - d["chunk_size"] = p.chunk_size; - d["diversity_penalty"] = p.diversity_penalty; - d["do_sample"] = p.do_sample; - d["early_stopping"] = p.early_stopping; - d["length_penalty"] = p.length_penalty; - d["max_length"] = p.max_length; - d["min_length"] = p.min_length; - d["no_repeat_ngram_size"] = p.no_repeat_ngram_size; - d["num_beams"] = p.num_beams; - d["num_return_sequences"] = p.num_return_sequences; - d["past_present_share_buffer"] = p.past_present_share_buffer; - d["random_seed"] = p.random_seed; - d["repetition_penalty"] = p.repetition_penalty; - d["temperature"] = p.temperature; - d["top_k"] = p.top_k; - d["top_p"] = p.top_p; - return d; - }); + .def("get_search_options", &PyGeneratorParams::GetSearchOptions); pybind11::class_(m, "TokenizerStream") .def("decode", [](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); }); @@ -410,7 +412,10 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("to_token_id", &OgaTokenizer::ToTokenId) .def("decode", [](const OgaTokenizer& t, pybind11::array_t tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; }) .def( - "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) + "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { + return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; + }, + pybind11::arg("messages"), pybind11::kw_only(), 
pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) .def("encode_batch", [](const OgaTokenizer& t, std::vector strings) { std::vector c_strings; for (const auto& s : strings) From a95ac690f8b8ad7d1cb8074b481e7e0355127ac7 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 03:03:56 +0000 Subject: [PATCH 21/38] Use fullname for pybind dict --- src/python/python.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index febd2c9421..8c18e56c75 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -196,8 +196,8 @@ struct PyGeneratorParams { params_->SetGuidance(type.c_str(), data.c_str(), enable_ff_tokens); } - py::dict GetSearchOptions() { - py::dict d; + pybind11::dict GetSearchOptions() { + pybind11::dict d; d["batch_size"] = params_->GetSearchNumber("batch_size"); d["chunk_size"] = params_->GetSearchNumber("chunk_size"); d["diversity_penalty"] = params_->GetSearchNumber("diversity_penalty"); From d70baa1dc658609b8090743949ac65cfde1438a2 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 03:07:20 +0000 Subject: [PATCH 22/38] Define TokenCount binding with PyGenerator instead of OgaGenerator --- src/python/python.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index 8c18e56c75..2a089d6e55 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -258,6 +258,10 @@ struct PyGenerator { generator_->AppendTokens(ToSpan(tokens)); } + int32_t TokenCount() const { + return generator_->TokenCount(); + } + pybind11::array_t GetLogits() { return ToNumpy(*generator_->GetLogits()); } @@ -483,7 +487,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("set_model_input", &PyGenerator::SetModelInput) .def("append_tokens", pybind11::overload_cast&>(&PyGenerator::AppendTokens)) .def("append_tokens", 
pybind11::overload_cast(&PyGenerator::AppendTokens)) - .def("token_count", &OgaGenerator::TokenCount) + .def("token_count", &PyGenerator::TokenCount) .def("get_logits", &PyGenerator::GetLogits) .def("set_logits", &PyGenerator::SetLogits) .def("generate_next_token", &PyGenerator::GenerateNextToken) From d160ea07e4cb26856808c31f5c4899bacb62fea7 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 03:57:52 +0000 Subject: [PATCH 23/38] Add changes suggested by C++ linter --- src/python/python.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index 2a089d6e55..ee2acaef13 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -416,10 +416,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("to_token_id", &OgaTokenizer::ToTokenId) .def("decode", [](const OgaTokenizer& t, pybind11::array_t tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; }) .def( - "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { - return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; - }, - pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) + "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) .def("encode_batch", [](const OgaTokenizer& t, std::vector strings) { std::vector c_strings; for (const auto& s : strings) From 36d6e3d0d53fa8c8c08f7f4188c6e7bb60300db6 Mon Sep 17 00:00:00 
2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 04:16:43 +0000 Subject: [PATCH 24/38] Move ApplyChatTemplate into one line and ignore local C++ linter --- src/python/python.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/python/python.cpp b/src/python/python.cpp index ee2acaef13..e56f9f1620 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -415,8 +415,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { return ToPython(sequences->Get(0)); }) .def("to_token_id", &OgaTokenizer::ToTokenId) .def("decode", [](const OgaTokenizer& t, pybind11::array_t tokens) -> std::string { return t.Decode(ToSpan(tokens)).p_; }) - .def( - "apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) + .def("apply_chat_template", [](const OgaTokenizer& t, const char* messages, const char* template_str, const char* tools, bool add_generation_prompt) -> std::string { return t.ApplyChatTemplate(template_str, messages, tools, add_generation_prompt).p_; }, pybind11::arg("messages"), pybind11::kw_only(), pybind11::arg("template_str") = nullptr, pybind11::arg("tools") = nullptr, pybind11::arg("add_generation_prompt") = true) .def("encode_batch", [](const OgaTokenizer& t, std::vector strings) { std::vector c_strings; for (const auto& s : strings) From ebe9eaf90251025e95fa00b43bcbc2c6a5432742 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 06:39:51 +0000 Subject: [PATCH 25/38] Add assertions in unit tests for new APIs --- .../src/test/java/ai/onnxruntime/genai/GenerationTest.java | 6 ++++++ .../java/ai/onnxruntime/genai/GeneratorParamsTest.java | 2 +- test/c_api_tests.cpp | 5 +++++ 
test/csharp/TestOnnxRuntimeGenAIAPI.cs | 7 ++++++- .../ios_package_testUITests/ios_package_uitest_cpp_api.mm | 5 +++++ .../macos_package_uitest_cpp_api.mm | 5 +++++ test/python/test_onnxruntime_genai_api.py | 7 ++++++- 7 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java index beb0a12ad9..4f22807529 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -138,6 +138,11 @@ public void testWithInputIds() throws GenAIException { try (Generator generator = new Generator(model, params); ) { generator.appendTokens(inputIDs); + + assertEquals(params.GetSearchNumber("max_length"), maxLength); + assertEquals(params.GetSearchBool("early_stopping"), true); + assertEquals(generator.TokenCount(), 10); + while (!generator.isDone()) { generator.generateNextToken(); } @@ -148,6 +153,7 @@ public void testWithInputIds() throws GenAIException { assertEquals(outputIds[j], expectedOutput[i * maxLength + j]); } } + assertEquals(generator.TokenCount(), generator.GetSequenceCount(0).length); } } } diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java index cd8698de38..944d3c4047 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java @@ -12,7 +12,7 @@ public class GeneratorParamsTest { @Test public void testValidSearchOption() throws GenAIException { - // test setting an invalid search option throws a GenAIException + // test setting a valid search option try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath()); GeneratorParams params = generator.createGeneratorParams(); ) { params.setSearchOption("early_stopping", true); // boolean diff --git 
a/test/c_api_tests.cpp b/test/c_api_tests.cpp index d8e38dc8f5..330ff1b43c 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -555,6 +555,10 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); + ASSERT_EQ(static_cast(params->GetSearchNumber("max_length")), 40); + ASSERT_EQ(params->GetSearchBool("early_stopping"), true); + ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + while (!generator->IsDone()) { generator->GenerateNextToken(); } @@ -573,6 +577,7 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { const auto* sequence_data = generator->GetSequenceData(0); ASSERT_LE(sequence_length, 40); + ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); const auto* expected_output_start = &expected_output[0]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index dec620d774..71368409b4 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -158,9 +158,13 @@ public void TestGreedySearch() using (var generator = new Generator(model, generatorParams)) { Assert.NotNull(generator); - generator.AppendTokens(inputIDs); + Assert.False(generator.IsDone()); + Assert.Equal(generatorParams.GetSearchNumber("max_length"), maxLength); + Assert.Equal(generatorParams.GetSearchBool("early_stopping"), true); + Assert.Equal(generator.TokenCount(), generator.GetSequence(0).Length); + while (!generator.IsDone()) { generator.GenerateNextToken(); @@ -171,6 +175,7 @@ public void TestGreedySearch() var sequence = generator.GetSequence(i).ToArray(); var expectedSequence = expectedOutput.Skip((int)i * (int)maxLength).Take((int)maxLength); Assert.Equal(expectedSequence, sequence); + Assert.Equal(generator.TokenCount(), 
generator.GetSequence(i).Length); } } } diff --git a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 40a46aaf03..25ed5e8172 100644 --- a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -54,6 +54,10 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); + XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100); + XCTAssertEqual(params->GetSearchBool("early_stopping"), true); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + while (!generator->IsDone()) { generator->GenerateNextToken(); } @@ -61,6 +65,7 @@ - (void)testCppAPI_Basic { const auto output_sequence_length = generator->GetSequenceCount(0); const auto* output_sequence_data = generator->GetSequenceData(0); auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); } @end diff --git a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index af5d0046ec..538289c322 100644 --- a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -54,6 +54,10 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); + XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100); + 
XCTAssertEqual(params->GetSearchBool("early_stopping"), true); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + while (!generator->IsDone()) { generator->GenerateNextToken(); } @@ -61,6 +65,7 @@ - (void)testCppAPI_Basic { const auto output_sequence_length = generator->GetSequenceCount(0); const auto* output_sequence_data = generator->GetSequenceData(0); auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); } @end diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index 5d33a73aaa..00374ad0f7 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -148,6 +148,11 @@ def test_greedy_search(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32)) + + assert int(search_params.get_search_options["max_length"]) == 40 + assert search_params.get_search_options["early_stopping"] == True + assert int(generator.token_count()) == 4 + while not generator.is_done(): # Test getting/setting logits logits = generator.get_logits() @@ -165,7 +170,7 @@ def test_greedy_search(test_data_path, relative_model_path): ) for i in range(batch_size): assert np.array_equal(expected_sequence[i], generator.get_sequence(i)) - + assert int(generator.token_count()) == len(generator.get_sequence(0)) @pytest.mark.parametrize( "relative_model_path", From e4bc16914fddb6ed4dee63bb05b2e2ba06e1b29b Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 06:55:47 +0000 Subject: [PATCH 26/38] Fix language binding API names --- .../test/java/ai/onnxruntime/genai/GenerationTest.java | 8 ++++---- test/python/test_onnxruntime_genai_api.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git 
a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java index 4f22807529..56c543a8ea 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -139,9 +139,9 @@ public void testWithInputIds() throws GenAIException { try (Generator generator = new Generator(model, params); ) { generator.appendTokens(inputIDs); - assertEquals(params.GetSearchNumber("max_length"), maxLength); - assertEquals(params.GetSearchBool("early_stopping"), true); - assertEquals(generator.TokenCount(), 10); + assertEquals(params.getSearchNumber("max_length"), maxLength); + assertEquals(params.getSearchBool("early_stopping"), true); + assertEquals(generator.tokenCount(), 10); while (!generator.isDone()) { generator.generateNextToken(); @@ -153,7 +153,7 @@ public void testWithInputIds() throws GenAIException { assertEquals(outputIds[j], expectedOutput[i * maxLength + j]); } } - assertEquals(generator.TokenCount(), generator.GetSequenceCount(0).length); + assertEquals(generator.tokenCount(), generator.getSequence(0).length); } } } diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index 00374ad0f7..eff69bd06c 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -149,8 +149,8 @@ def test_greedy_search(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32)) - assert int(search_params.get_search_options["max_length"]) == 40 - assert search_params.get_search_options["early_stopping"] == True + assert int(search_params.get_search_options()["max_length"]) == 40 + assert search_params.get_search_options()["early_stopping"] == True assert int(generator.token_count()) == 4 while not generator.is_done(): From 
5ca18c7a48266b39c18a207eb0ec1e38ad124355 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 09:18:59 +0000 Subject: [PATCH 27/38] Update how chunk size is obtained --- src/generators.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index b8e0c4b055..6e13b3a9cc 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -291,8 +291,8 @@ bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_t double GeneratorParams::GetSearchNumber(std::string_view name) const { if (name == "batch_size") { return static_cast(search.batch_size); - } else if (name == "chunk_size" && search.chunk_size.has_value()) { - return static_cast(search.chunk_size.value()); + } else if (name == "chunk_size") { + return search.chunk_size.has_value() ? static_cast(search.chunk_size.value()) : 0.0; } else if (name == "diversity_penalty") { return search.diversity_penalty; } else if (name == "length_penalty") { @@ -318,7 +318,7 @@ double GeneratorParams::GetSearchNumber(std::string_view name) const { } else if (name == "top_p") { return search.top_p; } else { - throw std::runtime_error("Invalid name for GetSearchNumber."); + throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchNumber."); } } @@ -330,7 +330,7 @@ bool GeneratorParams::GetSearchBool(std::string_view name) const { } else if (name == "past_present_share_buffer") { return search.past_present_share_buffer; } else { - throw std::runtime_error("Invalid name for GetSearchBool."); + throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchBool."); } } From 6f6c0b0eec3d0c0d2bcb7f33640616421344eddf Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 09:38:51 +0000 Subject: [PATCH 28/38] Fix max length in assertion --- test/python/test_onnxruntime_genai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/test_onnxruntime_genai_api.py 
b/test/python/test_onnxruntime_genai_api.py index eff69bd06c..837bd4726a 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -149,7 +149,7 @@ def test_greedy_search(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32)) - assert int(search_params.get_search_options()["max_length"]) == 40 + assert int(search_params.get_search_options()["max_length"]) == 10 assert search_params.get_search_options()["early_stopping"] == True assert int(generator.token_count()) == 4 From 1c2ed9d9dbd5060327128a6b578d3a8abfe80a50 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 10:25:55 +0000 Subject: [PATCH 29/38] Remove default values from Java bindings --- src/config.h | 2 +- .../src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/config.h b/src/config.h index a9a8073745..51358b8192 100644 --- a/src/config.h +++ b/src/config.h @@ -301,7 +301,7 @@ struct Config { int top_k{50}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model. float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. float temperature{1.0f}; // Temperature to control during generation. Default is 1.0. - bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not. + bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not. int no_repeat_ngram_size{}; // Unused param float diversity_penalty{}; // Unused param float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. 
length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index 139003ff43..400b30b699 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -48,7 +48,7 @@ JNIEXPORT jdouble JNICALL Java_ai_onnxruntime_genai_GeneratorParams_getSearchNumber(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); CString name{env, option_name}; - double value = 0.0; + double value; ThrowIfError(env, OgaGeneratorParamsGetSearchNumber(generator_params, name, &value)); return static_cast(value); @@ -58,7 +58,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_genai_GeneratorParams_getSearchBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); CString name{env, option_name}; - bool value = false; + bool value; ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, &value)); return static_cast(value); From fb1f8d5d2cfc63d4cc77a09f9a9e452162980fdd Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 10:26:36 +0000 Subject: [PATCH 30/38] Fix C API call inside Java API --- .../src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index 400b30b699..907ab93040 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -60,6 +60,6 @@ Java_ai_onnxruntime_genai_GeneratorParams_getSearchBool(JNIEnv* env, 
jobject thi CString name{env, option_name}; bool value; - ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, &value)); + ThrowIfError(env, OgaGeneratorParamsGetSearchBool(generator_params, name, &value)); return static_cast(value); } From 2197b1a0c705b089893953e98cb8c2d9a5e339e7 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 10:56:54 +0000 Subject: [PATCH 31/38] Fix value in assert --- src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java index 56c543a8ea..bc25395a1f 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -141,7 +141,7 @@ public void testWithInputIds() throws GenAIException { assertEquals(params.getSearchNumber("max_length"), maxLength); assertEquals(params.getSearchBool("early_stopping"), true); - assertEquals(generator.tokenCount(), 10); + assertEquals(generator.tokenCount(), 4); while (!generator.isDone()) { generator.generateNextToken(); From 7afeb518f03b271c74916b25e6ab318d5087ec10 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 22 Jan 2026 23:04:48 +0000 Subject: [PATCH 32/38] Remove breaking changes documentation from README --- README.md | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/README.md b/README.md index f57edf9248..29313d8feb 100644 --- a/README.md +++ b/README.md @@ -150,30 +150,6 @@ To install the nightly Python build: pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-genai ``` -## Breaking API changes - -Note: between `v0.11.5` and `v0.10.1`, there was a breaking API usage change to improve model quality during multi-turn conversations. 
In the process, the generation loop was changed to the following. - -``` -while True: - GenerateToken() - if IsDone(): - break - GetLastToken() - PrintLastToken() -``` - -With the addition of the GetLastToken() API across language bindings, the prior behavior for the decoding loop is restore. The decoding loop can now be written once again as follows. - -``` -while not IsDone(): - GenerateToken() - GetLastToken() - PrintLastToken() -``` - -Please read [this PR's description](https://github.com/microsoft/onnxruntime-genai/pull/1925) for more information. - ## Roadmap See the [Discussions](https://github.com/microsoft/onnxruntime-genai/discussions) to request new features and up-vote existing requests. From 3b990dfc6fe72d0f0fe1693f9bf36d50457bb076 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 24 Jan 2026 00:44:08 +0000 Subject: [PATCH 33/38] Construct vector in return statement --- src/ort_genai.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/ort_genai.h b/src/ort_genai.h index 73f8d64344..270421b6bd 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -324,12 +324,10 @@ struct OgaTokenizer : OgaAbstract { } #else std::vector GetEosTokenIds() const { - std::vector eos_ids; const int32_t* eos_ids_ptr; size_t count; OgaCheckResult(OgaTokenizerGetEosTokenIds(this, &eos_ids_ptr, &count)); - eos_ids.assign(eos_ids_ptr, eos_ids_ptr + count); - return eos_ids; + return std::vector(eos_ids_ptr, eos_ids_ptr + count); } #endif @@ -501,12 +499,10 @@ struct OgaGenerator : OgaAbstract { } #else std::vector GetNextTokens() { - std::vector next_tokens; const int32_t* out; size_t out_count; OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); - next_tokens.assign(out, out + out_count); - return next_tokens; + return std::vector(out, out + out_count); } #endif From 27d42b5c89c4b77319b04f86b05f71f9a2577403 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 26 Jan 2026 22:48:21 +0000 Subject: [PATCH 34/38] Make changes 
based on PR feedback --- src/csharp/Generator.cs | 9 ++++++--- src/csharp/NativeMethods.cs | 4 ++-- src/cuda/search_cuda.cpp | 2 -- src/cuda/search_cuda.h | 2 +- src/generators.cpp | 4 ++-- src/generators.h | 2 +- .../src/main/java/ai/onnxruntime/genai/Generator.java | 4 ++-- .../src/main/native/ai_onnxruntime_genai_Generator.cpp | 8 +++----- src/objectivec/error_utils.h | 5 +++-- src/objectivec/include/ort_genai_objc.h | 2 +- src/objectivec/oga_generator.mm | 6 +++--- src/objectivec/oga_sequences.mm | 4 ++-- src/objectivec/oga_tokenizer.mm | 4 ++-- src/ort_genai.h | 6 ++---- src/ort_genai_c.cpp | 7 ++----- src/ort_genai_c.h | 5 ++--- src/python/python.cpp | 2 +- src/search.h | 4 +--- 18 files changed, 36 insertions(+), 44 deletions(-) diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index 30f9ffbdb3..94f83f0798 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -46,10 +46,13 @@ public void AppendTokenSequences(Sequences sequences) Result.VerifySuccess(NativeMethods.OgaGenerator_AppendTokenSequences(_generatorHandle, sequences.Handle)); } - public int TokenCount() + /// <summary> + /// Gets the number of tokens in the generator + /// </summary> + /// <returns>The token count</returns> + public ulong TokenCount() { - Result.VerifySuccess(NativeMethods.OgaGenerator_TokenCount(_generatorHandle, out int count)); - return count; + return NativeMethods.OgaGenerator_TokenCount(_generatorHandle).ToUInt64(); } public void GenerateNextToken() diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index 9f1823e080..1ba9aac906 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -168,8 +168,8 @@ internal class NativeLib // This function is used to get the number of tokens in the generator.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerator_TokenCount(IntPtr /* OgaGenerator* */ generator, - out int /* int32_t* */ count); + public static extern UIntPtr OgaGenerator_TokenCount(IntPtr /* const OgaGenerator* */ generator); + // This function is used to rewind the generator to the given newLength. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index 36fa4a6ac9..b306d0f473 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -242,12 +242,10 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { } void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { - ResetDone(); auto next_tokens_gpu = next_tokens.Span(); cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetNextSequences().Span(), params_->search.batch_size, params_->search.num_beams, sequences_.max_length_, GetStream()); cuda::Launch_ExpandInputSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->search.batch_size, params_->search.num_beams, sequences_.max_length_, GetStream()); sequences_.AfterAppendNextTokens(next_tokens, params_->search.batch_size); // next_tokens is batch_size - ResetDone(); cudaStreamSynchronize(GetStream()); } diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index 6bcc398ff8..b5ead4c28c 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -19,7 +19,7 @@ struct Search_Cuda : Search { cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event - void ResetDone() override; + void ResetDone(); DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; diff --git a/src/generators.cpp b/src/generators.cpp index 6e13b3a9cc..412741b378 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -500,8 +500,8 @@ void Generator::SetRuntimeOption(const char* key, const char* 
value) { state_->SetRunOption(key, value); } -int32_t Generator::TokenCount() const { - return static_cast(search_->GetSequenceLength()); +size_t Generator::TokenCount() const { + return static_cast(search_->GetSequenceLength()); } bool Generator::IsDone() { diff --git a/src/generators.h b/src/generators.h index 1167b9d804..ccfed8db44 100644 --- a/src/generators.h +++ b/src/generators.h @@ -99,7 +99,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); - int32_t TokenCount() const; + size_t TokenCount() const; void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 301bff30bd..13869c25ca 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -135,7 +135,7 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException { * * @throws GenAIException If the call to the GenAI native API fails. 
*/ - public int tokenCount() throws GenAIException { + public long tokenCount() throws GenAIException { if (nativeHandle == 0) { throw new IllegalStateException("Instance has been freed and is invalid"); } @@ -293,7 +293,7 @@ private native void setModelInput(long nativeHandle, String inputName, long tens private native void appendTokenSequences(long nativeHandle, long sequencesHandle) throws GenAIException; - private native int tokenCount(long nativeHandle) throws GenAIException; + private native long tokenCount(long nativeHandle) throws GenAIException; private native void rewindTo(long nativeHandle, long newLength) throws GenAIException; diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp index 9ad21a0ac1..5327683ecb 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp @@ -67,13 +67,11 @@ Java_ai_onnxruntime_genai_Generator_appendTokens(JNIEnv* env, jobject thiz, jlon env->ReleaseIntArrayElements(token_ids, tokens, JNI_ABORT); } -JNIEXPORT jint JNICALL +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_genai_Generator_tokenCount(JNIEnv* env, jobject thiz, jlong native_handle) { OgaGenerator* generator = reinterpret_cast(native_handle); - int32_t count = 0; - - ThrowIfError(env, OgaGenerator_TokenCount(generator, &count)); - return static_cast(count); + size_t count = OgaGenerator_TokenCount(generator); + return static_cast(count); } JNIEXPORT jboolean JNICALL diff --git a/src/objectivec/error_utils.h b/src/objectivec/error_utils.h index 9359a23cf5..43672ee465 100644 --- a/src/objectivec/error_utils.h +++ b/src/objectivec/error_utils.h @@ -23,8 +23,9 @@ void OGASaveExceptionToError(const std::exception& e, NSError** error); } #define OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) OGA_OBJC_API_IMPL_CATCH(error, NO) -#define OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE (error) OGA_OBJC_API_IMPL_CATCH(error, 
0.0) -#define OGA_OBJC_API_IMPL_CATCH_RETURNING_INT (error) OGA_OBJC_API_IMPL_CATCH(error, 0) +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE(error) OGA_OBJC_API_IMPL_CATCH(error, double(0.0)) +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error) OGA_OBJC_API_IMPL_CATCH(error, int32_t(0)) #define OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) OGA_OBJC_API_IMPL_CATCH(error, nil) +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) NS_ASSUME_NONNULL_END diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index 5a6eb991f5..c3f23655c7 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -374,7 +374,7 @@ typedef NS_ENUM(NSInteger, OGAElementType) { * @param error Optional error information set if an error occurs. * @return The number of tokens in the generator */ -- (int)tokenCount:(NSError**)error; +- (size_t)tokenCount:(NSError**)error; /** * Rewinds the generator to the given length. 
diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm index 5b2c51adf4..6a9b9b3669 100644 --- a/src/objectivec/oga_generator.mm +++ b/src/objectivec/oga_generator.mm @@ -67,11 +67,11 @@ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } -- (int)tokenCount:(NSError**)error { +- (size_t)tokenCount:(NSError**)error { try { return _generator->TokenCount(); } - OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } - (BOOL)rewindTo:(size_t)newLength error:(NSError**)error { @@ -117,7 +117,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error { try { return _generator->GetSequenceCount(index); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } + (void)shutdown { diff --git a/src/objectivec/oga_sequences.mm b/src/objectivec/oga_sequences.mm index 7dd717b360..6ff7c74250 100644 --- a/src/objectivec/oga_sequences.mm +++ b/src/objectivec/oga_sequences.mm @@ -30,7 +30,7 @@ - (size_t)getCountWithError:(NSError**)error { try { return _sequences->Count(); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } - (nullable const int32_t*)sequenceDataAtIndex:(size_t)index error:(NSError**)error { @@ -44,7 +44,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error { try { return _sequences->SequenceCount(index); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } - (OgaSequences&)CXXAPIOgaSequences { diff --git a/src/objectivec/oga_tokenizer.mm b/src/objectivec/oga_tokenizer.mm index 10f67f7b13..a3eb27ffd2 100644 --- a/src/objectivec/oga_tokenizer.mm +++ b/src/objectivec/oga_tokenizer.mm @@ -25,7 +25,7 @@ - (int32_t)getBosTokenId:(NSError**)error { try { return _tokenizer->GetBosTokenId(); } - OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) + 
OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error) } - (nullable NSArray*)getEosTokenIds:(NSError**)error { @@ -47,7 +47,7 @@ - (int32_t)getPadTokenId:(NSError**)error { try { return _tokenizer->GetPadTokenId(); } - OGA_OBJC_API_IMPL_CATCH_RETURNING_INT(error) + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error) } - (nullable OGASequences*)encode:(NSString*)str error:(NSError**)error { diff --git a/src/ort_genai.h b/src/ort_genai.h index 270421b6bd..d6ea59d3f1 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -480,10 +480,8 @@ struct OgaGenerator : OgaAbstract { } #endif - int32_t TokenCount() const { - int32_t count; - OgaCheckResult(OgaGenerator_TokenCount(this, &count)); - return count; + size_t TokenCount() const { + return OgaGenerator_TokenCount(this); } void GenerateNextToken() { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 86738d917a..41c70cdb3c 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -464,11 +464,8 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator, int32_t* count) { - OGA_TRY - *count = generator->TokenCount(); - return nullptr; - OGA_CATCH +size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator) { + return generator->TokenCount(); } OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) { diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 820bc17459..de7d4fa484 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -522,10 +522,9 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* gener /** * \brief Returns the number of tokens in the generator * \param[in] generator The generator containing the appended tokens. - * \param[out] count The number of tokens that have been added. - * \return OgaResult containing the error message if the getting of the number of tokens failed. 
+ * \return The number of tokens that have been added. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator, int32_t* count); +OGA_EXPORT size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator); /** * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. diff --git a/src/python/python.cpp b/src/python/python.cpp index e56f9f1620..3626006b32 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -258,7 +258,7 @@ struct PyGenerator { generator_->AppendTokens(ToSpan(tokens)); } - int32_t TokenCount() const { + size_t TokenCount() const { return generator_->TokenCount(); } diff --git a/src/search.h b/src/search.h index 98c0e91a73..8921a23a8f 100644 --- a/src/search.h +++ b/src/search.h @@ -19,7 +19,6 @@ struct Search : LeakChecked { virtual DeviceSpan GetLogits() const = 0; virtual void SetLogits(DeviceSpan logits) = 0; virtual bool IsDone() const = 0; - virtual void ResetDone() = 0; virtual void SelectTop() = 0; virtual void SampleTopP(float /*p*/, float /*temperature*/) { assert(false); } @@ -45,7 +44,7 @@ struct Search_Cpu : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const override { return done_; } - void ResetDone() override; + void ResetDone(); DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; @@ -79,7 +78,6 @@ struct GreedySearch_Cpu : Search_Cpu { // Used by continuous decoding search. 
void AppendTokens(DeviceSpan& next_tokens) override; void RewindTo(size_t index) override; - void ResetDone() override; protected: void SetNextToken(size_t batch_id, int32_t token); From 4461a6ef52ec9588853f5ebca87f56ba45a9f4e0 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 26 Jan 2026 23:31:49 +0000 Subject: [PATCH 35/38] Add back missing definition --- src/search.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/search.h b/src/search.h index 8921a23a8f..b5d19de908 100644 --- a/src/search.h +++ b/src/search.h @@ -76,6 +76,7 @@ struct GreedySearch_Cpu : Search_Cpu { void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override; // Used by continuous decoding search. + void ResetDone(); void AppendTokens(DeviceSpan& next_tokens) override; void RewindTo(size_t index) override; From 6913f9f72afecc5284999aed41b1c42f641df487 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 26 Jan 2026 23:55:39 +0000 Subject: [PATCH 36/38] Pin transformers to be before v5 --- test/python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/requirements.txt b/test/python/requirements.txt index 3a851be2e5..2d7c639bdf 100644 --- a/test/python/requirements.txt +++ b/test/python/requirements.txt @@ -7,5 +7,5 @@ sympy pytest onnx onnx_ir>=0.1.3 -transformers +transformers<5.0.0 huggingface_hub[cli] From 5a668f354e89a26e78630dc6eb9f7c432c40ef73 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 27 Jan 2026 00:25:27 +0000 Subject: [PATCH 37/38] Cast token count from size_t to int in C# --- test/csharp/TestOnnxRuntimeGenAIAPI.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index 71368409b4..4e6a05b5cd 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -163,7 +163,7 @@ public void TestGreedySearch() Assert.False(generator.IsDone()); 
Assert.Equal(generatorParams.GetSearchNumber("max_length"), maxLength); Assert.Equal(generatorParams.GetSearchBool("early_stopping"), true); - Assert.Equal(generator.TokenCount(), generator.GetSequence(0).Length); + Assert.Equal((int)generator.TokenCount(), generator.GetSequence(0).Length); while (!generator.IsDone()) { @@ -175,7 +175,7 @@ public void TestGreedySearch() var sequence = generator.GetSequence(i).ToArray(); var expectedSequence = expectedOutput.Skip((int)i * (int)maxLength).Take((int)maxLength); Assert.Equal(expectedSequence, sequence); - Assert.Equal(generator.TokenCount(), generator.GetSequence(i).Length); + Assert.Equal((int)generator.TokenCount(), generator.GetSequence(i).Length); } } } From c2a45e9686d207201c5a4a1f90b20a1e18ce0c6f Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 27 Jan 2026 18:14:14 +0000 Subject: [PATCH 38/38] Simplify getting chunk size value --- src/generators.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generators.cpp b/src/generators.cpp index 412741b378..0473c51a07 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -292,7 +292,7 @@ double GeneratorParams::GetSearchNumber(std::string_view name) const { if (name == "batch_size") { return static_cast(search.batch_size); } else if (name == "chunk_size") { - return search.chunk_size.has_value() ? static_cast(search.chunk_size.value()) : 0.0; + return static_cast(search.chunk_size.value_or(0)); } else if (name == "diversity_penalty") { return search.diversity_penalty; } else if (name == "length_penalty") {