diff --git a/README.md b/README.md
index 2fd0d344f5..29313d8feb 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena
| Support matrix | Supported now | Under development | On the roadmap|
| -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | ChatGLMDeepSeekErnieFaraGemmaGPTOSSGraniteLlamaMistralNemotronOLMoPhiPhi3VPhi4MMQwenQwen-2.5VLSmolLM3Whisper| Stable diffusion ||
+| Model architectures | AMD OLMo
ChatGLM
DeepSeek
ERNIE 4.5
Fara
Gemma
gpt-oss
Granite
Llama
Mistral
Nemotron
Phi (language + vision)
Qwen (language + vision)
SmolLM3
Whisper | Stable diffusion | Multi-modal models |
| API| Python
C#
C/C++
Java ^ | Objective-C ||
| O/S | Linux
Windows
Mac
Android || iOS |||
| Architecture | x86
x64
arm64 ||||
@@ -78,10 +78,8 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install)
try:
generator.append_tokens(input_tokens)
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end='', flush=True)
except KeyboardInterrupt:
@@ -115,13 +113,13 @@ Windows:
pip list | findstr "onnxruntime-genai"
```
-Checkout the version of the examples that correspond to that release.
+Then, check out the version of the examples that corresponds to that release.
```bash
# Clone the repo
git clone https://github.com/microsoft/onnxruntime-genai.git && cd onnxruntime-genai
# Checkout the branch for the version you are using
-git checkout v0.11.4
+git checkout v0.11.5
cd examples
```
@@ -145,30 +143,11 @@ Navigate to the examples folder in the main branch.
cd examples
```
-## Breaking API changes
+To install the nightly Python build:
-### v0.11.0
-
-Between `v0.11.0` and `v0.10.1`, there is a breaking API usage change to improve model quality during multi-turn conversations.
-
-Previously, the decoding loop could be written as follows.
-
-```
-while not IsDone():
- GenerateToken()
- GetLastToken()
- PrintLastToken()
-```
-
-In 0.11.0, the decoding loop should now be written as follows.
-
-```
-while True:
- GenerateToken()
- if IsDone():
- break
- GetLastToken()
- PrintLastToken()
+```bash
+# Change onnxruntime-genai to the Python package you want to install
+pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-genai
```
## Roadmap
diff --git a/examples/c/src/model_chat.cpp b/examples/c/src/model_chat.cpp
index d8d6b7a882..4efe3a79bf 100644
--- a/examples/c/src/model_chat.cpp
+++ b/examples/c/src/model_chat.cpp
@@ -91,7 +91,7 @@ void CXX_API(const char* model_path, const char* execution_provider) {
const auto current_token_count = generator->GetSequenceCount(0);
try {
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
if (is_first_token) {
@@ -99,12 +99,7 @@ void CXX_API(const char* model_path, const char* execution_provider) {
is_first_token = false;
}
- if (generator->IsDone()) {
- break;
- }
-
- const auto num_tokens = generator->GetSequenceCount(0);
- const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+ const auto new_token = generator->GetNextTokens()[0];
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
}
} catch (const std::exception& e) {
diff --git a/examples/c/src/model_qa.cpp b/examples/c/src/model_qa.cpp
index 9c06659057..b6b4a3c229 100644
--- a/examples/c/src/model_qa.cpp
+++ b/examples/c/src/model_qa.cpp
@@ -82,7 +82,7 @@ void CXX_API(const char* model_path, const char* execution_provider) {
generator->AppendTokenSequences(*sequences);
try {
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
if (is_first_token) {
@@ -90,12 +90,7 @@ void CXX_API(const char* model_path, const char* execution_provider) {
is_first_token = false;
}
- if (generator->IsDone()) {
- break;
- }
-
- const auto num_tokens = generator->GetSequenceCount(0);
- const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+ const auto new_token = generator->GetNextTokens()[0];
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
}
} catch (const std::exception& e) {
diff --git a/examples/c/src/model_vision.cpp b/examples/c/src/model_vision.cpp
index 1be5de0470..b902ecee2c 100644
--- a/examples/c/src/model_vision.cpp
+++ b/examples/c/src/model_vision.cpp
@@ -90,15 +90,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*input_tensors);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
-
- if (generator->IsDone()) {
- break;
- }
-
- const auto num_tokens = generator->GetSequenceCount(0);
- const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+ const auto new_token = generator->GetNextTokens()[0];
std::cout << stream->Decode(new_token) << std::flush;
}
diff --git a/examples/c/src/phi4-mm.cpp b/examples/c/src/phi4-mm.cpp
index ab0be4d77c..e1deb781d4 100644
--- a/examples/c/src/phi4-mm.cpp
+++ b/examples/c/src/phi4-mm.cpp
@@ -101,15 +101,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*input_tensors);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
-
- if (generator->IsDone()) {
- break;
- }
-
- const auto num_tokens = generator->GetSequenceCount(0);
- const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+ const auto new_token = generator->GetNextTokens()[0];
std::cout << stream->Decode(new_token) << std::flush;
}
diff --git a/examples/c/src/whisper.cpp b/examples/c/src/whisper.cpp
index 48a657972c..32e8c9e029 100644
--- a/examples/c/src/whisper.cpp
+++ b/examples/c/src/whisper.cpp
@@ -57,11 +57,8 @@ void CXX_API(const char* model_path, int32_t num_beams) {
auto generator = OgaGenerator::Create(*model, *params);
generator->SetInputs(*inputs);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
for (size_t i = 0; i < static_cast(num_beams * batch_size); ++i) {
diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs
index c23ab9907c..e2f64dfc8f 100644
--- a/examples/csharp/HelloPhi/Program.cs
+++ b/examples/csharp/HelloPhi/Program.cs
@@ -127,13 +127,9 @@ static string GetPrompt(bool interactive)
using var generator = new Generator(model, generatorParams);
generator.AppendTokenSequences(sequences);
var watch = System.Diagnostics.Stopwatch.StartNew();
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
var outputSequence = generator.GetSequence(0);
@@ -155,13 +151,9 @@ static string GetPrompt(bool interactive)
using var generator = new Generator(model, generatorParams);
generator.AppendTokenSequences(sequences);
var watch = System.Diagnostics.Stopwatch.StartNew();
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0]));
}
Console.WriteLine();
@@ -196,13 +188,9 @@ static string GetPrompt(bool interactive)
var sequences = tokenizer.Encode(tokenizer.ApplyChatTemplate("", messages, "", true));
var watch = System.Diagnostics.Stopwatch.StartNew();
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0]));
}
Console.WriteLine();
diff --git a/examples/csharp/HelloPhi3V/Program.cs b/examples/csharp/HelloPhi3V/Program.cs
index 09e1038bd5..1d32cac199 100644
--- a/examples/csharp/HelloPhi3V/Program.cs
+++ b/examples/csharp/HelloPhi3V/Program.cs
@@ -168,13 +168,9 @@ void PrintUsage()
using var generator = new Generator(model, generatorParams);
generator.SetInputs(inputTensors);
var watch = System.Diagnostics.Stopwatch.StartNew();
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
Console.Write(stream.Decode(generator.GetNextTokens()[0]));
}
watch.Stop();
diff --git a/examples/csharp/HelloPhi4MM/Program.cs b/examples/csharp/HelloPhi4MM/Program.cs
index bcec8d714f..ce0ddf359d 100644
--- a/examples/csharp/HelloPhi4MM/Program.cs
+++ b/examples/csharp/HelloPhi4MM/Program.cs
@@ -215,13 +215,9 @@ void PrintUsage()
using var generator = new Generator(model, generatorParams);
generator.SetInputs(inputTensors);
var watch = System.Diagnostics.Stopwatch.StartNew();
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
Console.Write(stream.Decode(generator.GetNextTokens()[0]));
}
watch.Stop();
diff --git a/examples/python/awq-quantized-model.py b/examples/python/awq-quantized-model.py
index fd9991b3c7..465935e264 100644
--- a/examples/python/awq-quantized-model.py
+++ b/examples/python/awq-quantized-model.py
@@ -108,11 +108,8 @@ def run_model(args):
print("Output: ", end="", flush=True)
try:
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end="", flush=True)
except KeyboardInterrupt:
diff --git a/examples/python/model-chat.py b/examples/python/model-chat.py
index 44c9add4ff..ca9746b16a 100644
--- a/examples/python/model-chat.py
+++ b/examples/python/model-chat.py
@@ -216,16 +216,13 @@ def main(args):
print("Output: ", end="", flush=True)
try:
- while True:
+ while not generator.is_done():
generator.generate_next_token()
if args.timings:
if first:
first_token_timestamp = time.time()
first = False
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end="", flush=True)
if args.timings:
diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py
index 649d393c4a..86be56922b 100644
--- a/examples/python/model-generate.py
+++ b/examples/python/model-generate.py
@@ -99,10 +99,8 @@ def main(args):
if args.verbose:
print("Generating tokens ...\n")
start_time = time.time()
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
run_time = time.time() - start_time
for i in range(len(prompts)):
diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py
index cfd31cd946..a9b2cf7802 100644
--- a/examples/python/model-qa.py
+++ b/examples/python/model-qa.py
@@ -207,16 +207,13 @@ def main(args):
print("Output: ", end="", flush=True)
try:
- while True:
+ while not generator.is_done():
generator.generate_next_token()
if args.timings:
if first:
first_token_timestamp = time.time()
first = False
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end="", flush=True)
if args.timings:
diff --git a/examples/python/model-vision.py b/examples/python/model-vision.py
index acd95f1ff1..e226697e15 100644
--- a/examples/python/model-vision.py
+++ b/examples/python/model-vision.py
@@ -150,11 +150,8 @@ def run(args: argparse.Namespace):
generator.set_inputs(inputs)
start_time = time.time()
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(stream.decode(new_token), end="", flush=True)
diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py
index e0296d1905..645954fee3 100644
--- a/examples/python/phi3-qa.py
+++ b/examples/python/phi3-qa.py
@@ -75,16 +75,13 @@ def main(args):
print("Output: ", end="", flush=True)
try:
- while True:
+ while not generator.is_done():
generator.generate_next_token()
if args.timings:
if first:
first_token_timestamp = time.time()
first = False
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end="", flush=True)
if args.timings:
diff --git a/examples/python/phi4-mm.py b/examples/python/phi4-mm.py
index 80a1413942..c9aa155cd4 100644
--- a/examples/python/phi4-mm.py
+++ b/examples/python/phi4-mm.py
@@ -145,11 +145,8 @@ def run(args: argparse.Namespace):
generator.set_inputs(inputs)
start_time = time.time()
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
-
new_token = generator.get_next_tokens()[0]
print(tokenizer_stream.decode(new_token), end="", flush=True)
diff --git a/src/config.h b/src/config.h
index a9a8073745..51358b8192 100644
--- a/src/config.h
+++ b/src/config.h
@@ -301,7 +301,7 @@ struct Config {
int top_k{50}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
float temperature{1.0f}; // Temperature to control during generation. Default is 1.0.
- bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
+ bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
int no_repeat_ngram_size{}; // Unused param
float diversity_penalty{}; // Unused param
float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs
index 0c7eb31d81..94f83f0798 100644
--- a/src/csharp/Generator.cs
+++ b/src/csharp/Generator.cs
@@ -46,6 +46,15 @@ public void AppendTokenSequences(Sequences sequences)
Result.VerifySuccess(NativeMethods.OgaGenerator_AppendTokenSequences(_generatorHandle, sequences.Handle));
}
+ ///
+ /// Gets the number of tokens in the generator
+ ///
+ /// The token count
+ public ulong TokenCount()
+ {
+ return NativeMethods.OgaGenerator_TokenCount(_generatorHandle).ToUInt64();
+ }
+
public void GenerateNextToken()
{
Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle));
diff --git a/src/csharp/GeneratorParams.cs b/src/csharp/GeneratorParams.cs
index cc23eba0f9..4eb7e01fa7 100644
--- a/src/csharp/GeneratorParams.cs
+++ b/src/csharp/GeneratorParams.cs
@@ -29,14 +29,21 @@ public void SetSearchOption(string searchOption, bool value)
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
}
- public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize)
+ public void SetGuidance(string type, string data, bool enableFFTokens = false)
{
- Console.WriteLine("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release.");
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens));
}
- public void SetGuidance(string type, string data, bool enableFFTokens = false)
+ public double GetSearchNumber(string searchOption)
{
- Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens));
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out double value));
+ return value;
+ }
+
+ public bool GetSearchBool(string searchOption)
+ {
+ Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out bool value));
+ return value;
}
~GeneratorParams()
diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs
index 09abc65bac..1ba9aac906 100644
--- a/src/csharp/NativeMethods.cs
+++ b/src/csharp/NativeMethods.cs
@@ -109,18 +109,27 @@ internal class NativeLib
public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams,
byte[] /* const char* */ searchOption,
bool value);
-
- [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
- public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
- IntPtr /* const OgaGeneratorParams* */ generatorParams,
- out IntPtr /* OgaGenerator** */ generator);
-
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetGuidance(IntPtr /* OgaGeneratorParams* */ generatorParams,
byte[] /* const char* */ type,
byte[] /* const char* */ data,
bool /* boolean */ enable_ff_tokens);
+ [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+ public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchNumber(IntPtr /* OgaGeneratorParams* */ generatorParams,
+ byte[] /* const char* */ searchOption,
+ out double /* const double* */ value);
+
+ [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+ public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams,
+ byte[] /* const char* */ searchOption,
+ out bool /* const bool* */ value);
+
+ [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+ public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
+ IntPtr /* const OgaGeneratorParams* */ generatorParams,
+ out IntPtr /* OgaGenerator** */ generator);
+
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyGenerator(IntPtr /* OgaGenerator* */ generator);
@@ -157,6 +166,10 @@ internal class NativeLib
public static extern IntPtr /* OgaResult* */ OgaGenerator_AppendTokenSequences(IntPtr /* OgaGenerator* */ generator,
IntPtr /* const OgaSequences* */ sequences);
+ // This function is used to get the number of tokens in the generator.
+ [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
+ public static extern UIntPtr OgaGenerator_TokenCount(IntPtr /* const OgaGenerator* */ generator);
+
// This function is used to rewind the generator to the given newLength.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp
index ad5a7c6c5a..b306d0f473 100644
--- a/src/cuda/search_cuda.cpp
+++ b/src/cuda/search_cuda.cpp
@@ -25,14 +25,13 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params)
sequence_lengths_ = params.p_device->Allocate(batch_beam_size);
eos_seen_buffer_ = CudaMallocArray(batch_beam_size, &eos_seen_);
- cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream());
+ done_cpu_ = CudaMallocHostArray(1);
eos_token_ids_ = params.p_device->Allocate(params.config.model.eos_token_id.size());
copy(std::span{params.config.model.eos_token_id}, eos_token_ids_.CpuSpan());
eos_token_ids_.CopyCpuToDevice();
- done_cpu_ = CudaMallocHostArray(1);
- *done_cpu_ = false;
+ ResetDone();
}
GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params)
@@ -75,6 +74,11 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params)
BeamSearch_Cuda::~BeamSearch_Cuda() = default;
+void Search_Cuda::ResetDone() {
+ *done_cpu_ = false;
+ cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream());
+}
+
DeviceSpan Search_Cuda::GetLogits() const {
return next_token_scores_;
}
@@ -221,8 +225,7 @@ std::span Search_Cuda::GetScores() {
// Set user input tokens (batch_beam_size, sequence_length)
void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) {
- cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream());
- *done_cpu_ = false;
+ ResetDone();
auto next_tokens_gpu = next_tokens.Span();
cuda::Launch_AppendNextTokensToSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream());
@@ -235,8 +238,7 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) {
return;
}
- cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream());
- *done_cpu_ = false;
+ ResetDone();
}
void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) {
@@ -248,8 +250,7 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) {
}
void GreedySearch_Cuda::RewindTo(size_t index) {
- cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream());
- *done_cpu_ = false;
+ ResetDone();
if (index > 0)
cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast(params_->BatchBeamSize()), static_cast(index), sequences_.max_length_, GetStream());
else
diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h
index 6d0f03f73d..b5ead4c28c 100644
--- a/src/cuda/search_cuda.h
+++ b/src/cuda/search_cuda.h
@@ -19,6 +19,7 @@ struct Search_Cuda : Search {
cudaStreamSynchronize(GetStream());
return *done_cpu_;
} // TODO: Use an event
+ void ResetDone();
DeviceSpan GetLogits() const override;
void SetLogits(DeviceSpan logits) override;
diff --git a/src/generators.cpp b/src/generators.cpp
index d163eeacde..0473c51a07 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -288,6 +288,52 @@ bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_t
(search.num_beams == 1 || model_type == "whisper");
}
+double GeneratorParams::GetSearchNumber(std::string_view name) const {
+ if (name == "batch_size") {
+ return static_cast(search.batch_size);
+ } else if (name == "chunk_size") {
+ return static_cast(search.chunk_size.value_or(0));
+ } else if (name == "diversity_penalty") {
+ return search.diversity_penalty;
+ } else if (name == "length_penalty") {
+ return search.length_penalty;
+ } else if (name == "max_length") {
+ return static_cast(search.max_length);
+ } else if (name == "min_length") {
+ return static_cast(search.min_length);
+ } else if (name == "no_repeat_ngram_size") {
+ return static_cast(search.no_repeat_ngram_size);
+ } else if (name == "num_beams") {
+ return static_cast(search.num_beams);
+ } else if (name == "num_return_sequences") {
+ return static_cast(search.num_return_sequences);
+ } else if (name == "random_seed") {
+ return static_cast(search.random_seed);
+ } else if (name == "repetition_penalty") {
+ return search.repetition_penalty;
+ } else if (name == "temperature") {
+ return search.temperature;
+ } else if (name == "top_k") {
+ return static_cast(search.top_k);
+ } else if (name == "top_p") {
+ return search.top_p;
+ } else {
+ throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchNumber.");
+ }
+}
+
+bool GeneratorParams::GetSearchBool(std::string_view name) const {
+ if (name == "do_sample") {
+ return search.do_sample;
+ } else if (name == "early_stopping") {
+ return search.early_stopping;
+ } else if (name == "past_present_share_buffer") {
+ return search.past_present_share_buffer;
+ } else {
+ throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchBool.");
+ }
+}
+
std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params) {
return std::make_unique(model, params);
}
@@ -454,6 +500,10 @@ void Generator::SetRuntimeOption(const char* key, const char* value) {
state_->SetRunOption(key, value);
}
+size_t Generator::TokenCount() const {
+ return static_cast(search_->GetSequenceLength());
+}
+
bool Generator::IsDone() {
ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (computed_logits_) {
diff --git a/src/generators.h b/src/generators.h
index 1d643e198b..ccfed8db44 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -74,6 +74,10 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec
const Config& config; // The model outlives the GeneratorParams
Config::Search search{config.search}; // Copy of the search parameters from the config
+ // Query the params to get the value set for a param
+ double GetSearchNumber(std::string_view name) const;
+ bool GetSearchBool(std::string_view name) const;
+
int max_batch_size{0};
bool use_graph_capture{};
bool use_multi_profile{};
@@ -95,6 +99,7 @@ struct Generator : LeakChecked {
Generator(const Model& model, const GeneratorParams& params);
bool IsDone();
+ size_t TokenCount() const;
void AppendTokens(cpu_span input_ids);
void GenerateNextToken();
void RewindToLength(size_t new_length); // Rewind state to new_length
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java
index 66489ed8df..13869c25ca 100644
--- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java
+++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java
@@ -130,6 +130,19 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException {
appendTokenSequences(nativeHandle, sequences.nativeHandle());
}
+ /**
+ * Returns the token count in the generator.
+ *
+ * @throws GenAIException If the call to the GenAI native API fails.
+ */
+ public long tokenCount() throws GenAIException {
+ if (nativeHandle == 0) {
+ throw new IllegalStateException("Instance has been freed and is invalid");
+ }
+
+ return tokenCount(nativeHandle);
+ }
+
/**
* Rewinds the generator to the given length. This is useful when the user wants to rewind the
* generator to a specific length and continue generating from that point.
@@ -280,6 +293,8 @@ private native void setModelInput(long nativeHandle, String inputName, long tens
private native void appendTokenSequences(long nativeHandle, long sequencesHandle)
throws GenAIException;
+ private native long tokenCount(long nativeHandle) throws GenAIException;
+
private native void rewindTo(long nativeHandle, long newLength) throws GenAIException;
private native void generateNextTokenNative(long nativeHandle) throws GenAIException;
diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
index 7bf8306f4a..c6fd3f4945 100644
--- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
+++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java
@@ -29,7 +29,7 @@ public GeneratorParams(Model model) throws GenAIException {
}
/**
- * Set seach option with double value.
+ * Set search option with double value.
*
* @param optionName The option name.
* @param value The option value.
@@ -58,6 +58,36 @@ public void setSearchOption(String optionName, boolean value) throws GenAIExcept
setSearchOptionBool(nativeHandle, optionName, value);
}
+ /**
+ * Get search option with numerical value.
+ *
+ * @param optionName The option name.
+ * @return The option value.
+ * @throws GenAIException If the call to the GenAI native API fails.
+ */
+ public double getSearchNumber(String optionName) throws GenAIException {
+ if (nativeHandle == 0) {
+ throw new IllegalStateException("Instance has been freed and is invalid");
+ }
+
+ return getSearchNumber(nativeHandle, optionName);
+ }
+
+ /**
+ * Get search option with boolean value.
+ *
+ * @param optionName The option name.
+ * @return The option value.
+ * @throws GenAIException If the call to the GenAI native API fails.
+ */
+ public boolean getSearchBool(String optionName) throws GenAIException {
+ if (nativeHandle == 0) {
+ throw new IllegalStateException("Instance has been freed and is invalid");
+ }
+
+ return getSearchBool(nativeHandle, optionName);
+ }
+
@Override
public void close() {
if (nativeHandle != 0) {
@@ -87,4 +117,8 @@ private native void setSearchOptionNumber(long nativeHandle, String optionName,
private native void setSearchOptionBool(long nativeHandle, String optionName, boolean value)
throws GenAIException;
+
+ private native double getSearchNumber(long nativeHandle, String optionName) throws GenAIException;
+
+ private native boolean getSearchBool(long nativeHandle, String optionName) throws GenAIException;
}
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
index c75e397917..5327683ecb 100644
--- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
+++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
@@ -67,6 +67,13 @@ Java_ai_onnxruntime_genai_Generator_appendTokens(JNIEnv* env, jobject thiz, jlon
env->ReleaseIntArrayElements(token_ids, tokens, JNI_ABORT);
}
+JNIEXPORT jlong JNICALL
+Java_ai_onnxruntime_genai_Generator_tokenCount(JNIEnv* env, jobject thiz, jlong native_handle) {
+ OgaGenerator* generator = reinterpret_cast(native_handle);
+ size_t count = OgaGenerator_TokenCount(generator);
+ return static_cast(count);
+}
+
JNIEXPORT jboolean JNICALL
Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) {
return OgaGenerator_IsDone(reinterpret_cast(native_handle));
diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp
index 3664c0e3e8..907ab93040 100644
--- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp
+++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp
@@ -43,3 +43,23 @@ Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobje
ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, value));
}
+
+JNIEXPORT jdouble JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_getSearchNumber(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) {
+ OgaGeneratorParams* generator_params = reinterpret_cast(native_handle);
+ CString name{env, option_name};
+ double value;
+
+ ThrowIfError(env, OgaGeneratorParamsGetSearchNumber(generator_params, name, &value));
+ return static_cast(value);
+}
+
+JNIEXPORT jboolean JNICALL
+Java_ai_onnxruntime_genai_GeneratorParams_getSearchBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) {
+ OgaGeneratorParams* generator_params = reinterpret_cast(native_handle);
+ CString name{env, option_name};
+ bool value;
+
+ ThrowIfError(env, OgaGeneratorParamsGetSearchBool(generator_params, name, &value));
+ return static_cast(value);
+}
diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java
index beb0a12ad9..bc25395a1f 100644
--- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java
+++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java
@@ -138,6 +138,11 @@ public void testWithInputIds() throws GenAIException {
try (Generator generator = new Generator(model, params); ) {
generator.appendTokens(inputIDs);
+
+ assertEquals(params.getSearchNumber("max_length"), maxLength);
+ assertEquals(params.getSearchBool("early_stopping"), true);
+ assertEquals(generator.tokenCount(), 4);
+
while (!generator.isDone()) {
generator.generateNextToken();
}
@@ -148,6 +153,7 @@ public void testWithInputIds() throws GenAIException {
assertEquals(outputIds[j], expectedOutput[i * maxLength + j]);
}
}
+ assertEquals(generator.tokenCount(), generator.getSequence(0).length);
}
}
}
diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java
index cd8698de38..944d3c4047 100644
--- a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java
+++ b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java
@@ -12,7 +12,7 @@
public class GeneratorParamsTest {
@Test
public void testValidSearchOption() throws GenAIException {
- // test setting an invalid search option throws a GenAIException
+ // test setting a valid search option
try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath());
GeneratorParams params = generator.createGeneratorParams(); ) {
params.setSearchOption("early_stopping", true); // boolean
diff --git a/src/objectivec/error_utils.h b/src/objectivec/error_utils.h
index 8c73faa971..43672ee465 100644
--- a/src/objectivec/error_utils.h
+++ b/src/objectivec/error_utils.h
@@ -23,7 +23,9 @@ void OGASaveExceptionToError(const std::exception& e, NSError** error);
}
#define OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) OGA_OBJC_API_IMPL_CATCH(error, NO)
-
+#define OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE (error) OGA_OBJC_API_IMPL_CATCH(error, double(0.0))
+#define OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T (error) OGA_OBJC_API_IMPL_CATCH(error, int32_t(0))
#define OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) OGA_OBJC_API_IMPL_CATCH(error, nil)
+#define OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T (error) OGA_OBJC_API_IMPL_CATCH(error, size_t(-1))
NS_ASSUME_NONNULL_END
diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h
index 2e3d07c836..c3f23655c7 100644
--- a/src/objectivec/include/ort_genai_objc.h
+++ b/src/objectivec/include/ort_genai_objc.h
@@ -146,6 +146,27 @@ typedef NS_ENUM(NSInteger, OGAElementType) {
- (nullable instancetype)initWithModel:(OGAModel*)model
error:(NSError**)error NS_DESIGNATED_INITIALIZER;
+/**
+ * Return the int representation of the BOS token.
+ *
+ * @return The BOS token id
+ */
+- (int32_t)getBosTokenId:(NSError**)error;
+
+/**
+ * Return the int representations of the array of EOS tokens.
+ *
+ * @return The array of EOS token ids
+ */
+- (nullable NSArray*)getEosTokenIds:(NSError**)error;
+
+/**
+ * Return the int representation of the PAD token.
+ *
+ * @return The PAD token id
+ */
+- (int32_t)getPadTokenId:(NSError**)error;
+
/**
* Encode text to sequences
*
@@ -273,6 +294,23 @@ typedef NS_ENUM(NSInteger, OGAElementType) {
- (BOOL)setSearchOption:(NSString*)key
boolValue:(BOOL)value
error:(NSError**)error;
+
+/**
+ * Get numerical value of option.
+ * @param key The option key.
+ * @param error Optional error information set if an error occurs.
+ * @return The option value.
+ */
+- (double)getSearchNumber:(NSString*)key
+ error:(NSError**)error;
+/**
+ * Get boolean value of option.
+ * @param key The option key.
+ * @param error Optional error information set if an error occurs.
+ * @return The option value.
+ */
+- (BOOL)getSearchBool:(NSString*)key
+ error:(NSError**)error;
@end
/**
@@ -331,6 +369,13 @@ typedef NS_ENUM(NSInteger, OGAElementType) {
*/
- (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error;
+/**
+ * Get the number of tokens in the generator.
+ * @param error Optional error information set if an error occurs.
+ * @return The number of tokens in the generator
+ */
+- (size_t)tokenCount:(NSError**)error;
+
/**
* Rewinds the generator to the given length.
* @param newLength The desired length in tokens after rewinding.
diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm
index b5ca0e941c..6a9b9b3669 100644
--- a/src/objectivec/oga_generator.mm
+++ b/src/objectivec/oga_generator.mm
@@ -67,6 +67,13 @@ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error {
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
+- (size_t)tokenCount:(NSError**)error {
+ try {
+ return _generator->TokenCount();
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error)
+}
+
- (BOOL)rewindTo:(size_t)newLength error:(NSError**)error {
try {
_generator->RewindTo(newLength);
@@ -110,7 +117,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error {
try {
return _generator->GetSequenceCount(index);
}
- OGA_OBJC_API_IMPL_CATCH(error, size_t(-1))
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error)
}
+ (void)shutdown {
diff --git a/src/objectivec/oga_generator_params.mm b/src/objectivec/oga_generator_params.mm
index 1ae8719d41..25ecfd0386 100644
--- a/src/objectivec/oga_generator_params.mm
+++ b/src/objectivec/oga_generator_params.mm
@@ -37,6 +37,20 @@ - (BOOL)setSearchOption:(NSString*)key boolValue:(BOOL)value error:(NSError**)er
OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
}
+- (double)getSearchNumber:(NSString*)key error:(NSError**)error {
+ try {
+ return _generatorParams->GetSearchNumber([key UTF8String]);
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE(error)
+}
+
+- (BOOL)getSearchBool:(NSString*)key error:(NSError**)error {
+ try {
+ return _generatorParams->GetSearchBool([key UTF8String]);
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error)
+}
+
- (OgaGeneratorParams&)CXXAPIOgaGeneratorParams {
return *(_generatorParams.get());
}
diff --git a/src/objectivec/oga_sequences.mm b/src/objectivec/oga_sequences.mm
index 7dd717b360..6ff7c74250 100644
--- a/src/objectivec/oga_sequences.mm
+++ b/src/objectivec/oga_sequences.mm
@@ -30,7 +30,7 @@ - (size_t)getCountWithError:(NSError**)error {
try {
return _sequences->Count();
}
- OGA_OBJC_API_IMPL_CATCH(error, size_t(-1))
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error)
}
- (nullable const int32_t*)sequenceDataAtIndex:(size_t)index error:(NSError**)error {
@@ -44,7 +44,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error {
try {
return _sequences->SequenceCount(index);
}
- OGA_OBJC_API_IMPL_CATCH(error, size_t(-1))
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error)
}
- (OgaSequences&)CXXAPIOgaSequences {
diff --git a/src/objectivec/oga_tokenizer.mm b/src/objectivec/oga_tokenizer.mm
index c961faabc8..a3eb27ffd2 100644
--- a/src/objectivec/oga_tokenizer.mm
+++ b/src/objectivec/oga_tokenizer.mm
@@ -21,6 +21,35 @@ - (nullable instancetype)initWithModel:(OGAModel*)model error:(NSError**)error {
OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
}
+- (int32_t)getBosTokenId:(NSError**)error {
+ try {
+ return _tokenizer->GetBosTokenId();
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error)
+}
+
+- (nullable NSArray*)getEosTokenIds:(NSError**)error {
+ try {
+ std::vector eos_ids = _tokenizer->GetEosTokenIds();
+ NSMutableArray* result = [NSMutableArray arrayWithCapacity:eos_ids.size()];
+ if (!result) {
+ return nil;
+ }
+ for (int32_t eos_id : eos_ids) {
+ [result addObject:@(eos_id)];
+ }
+ return result;
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
+}
+
+- (int32_t)getPadTokenId:(NSError**)error {
+ try {
+ return _tokenizer->GetPadTokenId();
+ }
+ OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error)
+}
+
- (nullable OGASequences*)encode:(NSString*)str error:(NSError**)error {
OGASequences* sequences = [[OGASequences alloc] initWithError:error];
if (!sequences) {
diff --git a/src/ort_genai.h b/src/ort_genai.h
index dc38a1b967..d6ea59d3f1 100644
--- a/src/ort_genai.h
+++ b/src/ort_genai.h
@@ -324,12 +324,10 @@ struct OgaTokenizer : OgaAbstract {
}
#else
std::vector GetEosTokenIds() const {
- std::vector eos_ids;
const int32_t* eos_ids_ptr;
size_t count;
OgaCheckResult(OgaTokenizerGetEosTokenIds(this, &eos_ids_ptr, &count));
- eos_ids.assign(eos_ids_ptr, eos_ids_ptr + count);
- return eos_ids;
+ return std::vector(eos_ids_ptr, eos_ids_ptr + count);
}
#endif
@@ -426,14 +424,22 @@ struct OgaGeneratorParams : OgaAbstract {
OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value));
}
- void TryGraphCaptureWithMaxBatchSize(int /*max_batch_size*/) {
- printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n");
- }
-
void SetGuidance(const char* type, const char* data, bool enable_ff_tokens = false) {
OgaCheckResult(OgaGeneratorParamsSetGuidance(this, type, data, enable_ff_tokens));
}
+ double GetSearchNumber(const char* name) const {
+ double value;
+ OgaCheckResult(OgaGeneratorParamsGetSearchNumber(this, name, &value));
+ return value;
+ }
+
+ bool GetSearchBool(const char* name) const {
+ bool value;
+ OgaCheckResult(OgaGeneratorParamsGetSearchBool(this, name, &value));
+ return value;
+ }
+
static void operator delete(void* p) { OgaDestroyGeneratorParams(reinterpret_cast(p)); }
};
@@ -448,6 +454,10 @@ struct OgaGenerator : OgaAbstract {
return OgaGenerator_IsDone(this);
}
+ bool IsSessionTerminated() const {
+ return OgaGenerator_IsSessionTerminated(this);
+ }
+
void SetModelInput(const char* name, OgaTensor& tensor) {
OgaCheckResult(OgaGenerator_SetModelInput(this, name, &tensor));
}
@@ -470,8 +480,8 @@ struct OgaGenerator : OgaAbstract {
}
#endif
- bool IsSessionTerminated() const {
- return OgaGenerator_IsSessionTerminated(this);
+ size_t TokenCount() const {
+ return OgaGenerator_TokenCount(this);
}
void GenerateNextToken() {
@@ -485,6 +495,13 @@ struct OgaGenerator : OgaAbstract {
OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count));
return {out, out_count};
}
+#else
+ std::vector GetNextTokens() {
+ const int32_t* out;
+ size_t out_count;
+ OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count));
+ return std::vector(out, out + out_count);
+ }
#endif
void RewindTo(size_t new_length) {
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index e18b424b0b..41c70cdb3c 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -388,16 +388,23 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* para
OGA_CATCH
}
-OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size) {
+OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) {
OGA_TRY
- printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n");
+ params->SetGuidance(type, data, enable_ff_tokens);
return nullptr;
OGA_CATCH
}
-OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) {
+OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value) {
OGA_TRY
- params->SetGuidance(type, data, enable_ff_tokens);
+ *value = params->GetSearchNumber(name);
+ return nullptr;
+ OGA_CATCH
+}
+
+OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value) {
+ OGA_TRY
+ *value = params->GetSearchBool(name);
return nullptr;
OGA_CATCH
}
@@ -457,6 +464,10 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const
OGA_CATCH
}
+size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator) {
+ return generator->TokenCount();
+}
+
OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) {
OGA_TRY
generator->GenerateNextToken();
diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index 7b7c767587..de7d4fa484 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -411,9 +411,23 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* mode
*/
OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* params);
+/**
+ * \brief Set a numerical value for a search parameter
+ * \param[in] params The generator params to set.
+ * \param[in] name The name of the search parameter.
+ * \param[in] value The value of the search parameter.
+ * \return OgaResult containing the error message if setting the generator params failed.
+ */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* params, const char* name, double value);
+
+/**
+ * \brief Set a boolean value for a search parameter
+ * \param[in] params The generator params to set.
+ * \param[in] name The name of the search parameter.
+ * \param[in] value The value of the search parameter.
+ * \return OgaResult containing the error message if setting the generator params failed.
+ */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* params, const char* name, bool value);
-OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size);
/**
* \brief Sets the guidance type and data for the Generator params
@@ -425,6 +439,24 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatch
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens);
+/**
+ * \brief Get a numerical value for a search parameter
+ * \param[in] params The generator params to set.
+ * \param[in] name The name of the search parameter.
+ * \param[out] value The value of the search parameter.
+ * \return OgaResult containing the error message if setting the generator params failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value);
+
+/**
+ * \brief Get a boolean value for a search parameter
+ * \param[in] params The generator params to set.
+ * \param[in] name The name of the search parameter.
+ * \param[out] value The value of the search parameter.
+ * \return OgaResult containing the error message if setting the generator params failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value);
+
/**
* \brief Creates a generator from the given model and generator params.
* \param[in] model The model to use for generation.
@@ -446,6 +478,12 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);
* \return True if the generator has finished generating all the sequences, false otherwise.
*/
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator);
+
+/**
+ * \brief Returns true if the session has been terminated.
+ * \param[in] generator The generator to add the inputs to.
+ * \return True if the session has been terminated, false otherwise.
+ */
OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator);
/**
@@ -481,6 +519,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerato
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const int32_t* input_ids, size_t input_ids_count);
+/**
+ * \brief Returns the number of tokens in the generator
+ * \param[in] generator The generator containing the appended tokens.
+ * \return The number of tokens that have been added.
+ */
+OGA_EXPORT size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator);
+
/**
* \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator.
* \param[in] generator The generator to compute the logits for.
@@ -497,6 +542,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator*
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetNextTokens(const OgaGenerator* generator, const int32_t** out, size_t* out_count);
+/**
+ * \brief Set a runtime option's name and value.
+ * \param[in] generator The generator to rewind to the given length.
+ * \param[in] key The runtime option's name
+ * \param[in] value The runtime option's value
+ * \return OgaResult containing the error message if setting the runtime option failed.
+ */
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value);
/**
@@ -611,17 +663,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaUpdateTokenizerOptions(
size_t num_options);
/**
- * Return the int representation of the BOS token
+ * \brief Return the int representation of the BOS token
+ * \param[in] tokenizer The tokenizer to read from
+ * \param[out] token_id The BOS token id
+ * \return OgaResult containing the error message if returning the BOS token id fails.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetBosTokenId(const OgaTokenizer* tokenizer, int32_t* token_id);
/**
- * Return an array containing the int representations of the EOS tokens. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed.
+ * \brief Return an array containing the int representations of the EOS token ids. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed.
+ * \param[in] tokenizer The tokenizer to read from
+ * \param[out] eos_token_ids The array of EOS token ids
+ * \param[out] token_count The length of the array
+ * \return OgaResult containing the error message if returning the EOS token ids fails.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetEosTokenIds(const OgaTokenizer* tokenizer, const int32_t** eos_token_ids, size_t* token_count);
/**
- * Return the int representation of the BOS token
+ * \brief Return the int representation of the PAD token
+ * \param[in] tokenizer The tokenizer to read from
+ * \param[out] token_id The PAD token id
+ * \return OgaResult containing the error message if returning the PAD token id fails.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetPadTokenId(const OgaTokenizer* tokenizer, int32_t* token_id);
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 547a116b0f..3626006b32 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -192,14 +192,32 @@ struct PyGeneratorParams {
}
}
- void TryGraphCaptureWithMaxBatchSize(pybind11::int_ max_batch_size) {
- std::cerr << "TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release" << std::endl;
- }
-
void SetGuidance(const std::string& type, const std::string& data, bool enable_ff_tokens = false) {
params_->SetGuidance(type.c_str(), data.c_str(), enable_ff_tokens);
}
+ pybind11::dict GetSearchOptions() {
+ pybind11::dict d;
+ d["batch_size"] = params_->GetSearchNumber("batch_size");
+ d["chunk_size"] = params_->GetSearchNumber("chunk_size");
+ d["diversity_penalty"] = params_->GetSearchNumber("diversity_penalty");
+ d["do_sample"] = params_->GetSearchBool("do_sample");
+ d["early_stopping"] = params_->GetSearchBool("early_stopping");
+ d["length_penalty"] = params_->GetSearchNumber("length_penalty");
+ d["max_length"] = params_->GetSearchNumber("max_length");
+ d["min_length"] = params_->GetSearchNumber("min_length");
+ d["no_repeat_ngram_size"] = params_->GetSearchNumber("no_repeat_ngram_size");
+ d["num_beams"] = params_->GetSearchNumber("num_beams");
+ d["num_return_sequences"] = params_->GetSearchNumber("num_return_sequences");
+ d["past_present_share_buffer"] = params_->GetSearchBool("past_present_share_buffer");
+ d["random_seed"] = params_->GetSearchNumber("random_seed");
+ d["repetition_penalty"] = params_->GetSearchNumber("repetition_penalty");
+ d["temperature"] = params_->GetSearchNumber("temperature");
+ d["top_k"] = params_->GetSearchNumber("top_k");
+ d["top_p"] = params_->GetSearchNumber("top_p");
+ return d;
+ }
+
std::vector refs_; // References to data we want to ensure doesn't get garbage collected
};
@@ -240,6 +258,10 @@ struct PyGenerator {
generator_->AppendTokens(ToSpan(tokens));
}
+ size_t TokenCount() const {
+ return generator_->TokenCount();
+ }
+
pybind11::array_t GetLogits() {
return ToNumpy(*generator_->GetLogits());
}
@@ -316,11 +338,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
pybind11::class_(m, "GeneratorParams")
.def(pybind11::init())
- .def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize)
.def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options
.def("set_guidance", &PyGeneratorParams::SetGuidance,
pybind11::arg("type"), pybind11::arg("data"),
- pybind11::arg("enable_ff_tokens") = false);
+ pybind11::arg("enable_ff_tokens") = false)
+ .def("get_search_options", &PyGeneratorParams::GetSearchOptions);
pybind11::class_(m, "TokenizerStream")
.def("decode", [](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); });
@@ -461,6 +483,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def("set_model_input", &PyGenerator::SetModelInput)
.def("append_tokens", pybind11::overload_cast&>(&PyGenerator::AppendTokens))
.def("append_tokens", pybind11::overload_cast(&PyGenerator::AppendTokens))
+ .def("token_count", &PyGenerator::TokenCount)
.def("get_logits", &PyGenerator::GetLogits)
.def("set_logits", &PyGenerator::SetLogits)
.def("generate_next_token", &PyGenerator::GenerateNextToken)
diff --git a/src/search.cpp b/src/search.cpp
index 1d69c30115..3e9d545c56 100644
--- a/src/search.cpp
+++ b/src/search.cpp
@@ -49,6 +49,17 @@ BeamSearch_Cpu::BeamSearch_Cpu(const GeneratorParams& params)
BeamSearch_Cpu::~BeamSearch_Cpu() = default;
+void Search_Cpu::ResetDone() {
+ // Reset done count/state
+ done_ = false;
+}
+
+void GreedySearch_Cpu::ResetDone() {
+ Search_Cpu::ResetDone();
+ not_done_count_ = params_->search.batch_size;
+ memset(eos_seen_.data(), 0, eos_seen_.size_bytes());
+}
+
DeviceSpan Search_Cpu::GetLogits() const {
return next_token_scores_;
}
@@ -333,16 +344,13 @@ void GreedySearch_Cpu::AppendTokens(DeviceSpan& next_tokens) {
}
AppendNextTokensToSequences();
}
- // Reset done count/state
- done_ = false;
- not_done_count_ = params_->search.batch_size;
- memset(eos_seen_.data(), 0, eos_seen_.size_bytes());
+
+ ResetDone();
}
void GreedySearch_Cpu::RewindTo(size_t index) {
- done_ = false;
- not_done_count_ = params_->search.batch_size;
- memset(eos_seen_.data(), 0, eos_seen_.size_bytes());
+ ResetDone();
+
// Set next tokens to the last tokens in the sequence
if (index > 0) {
for (int i = 0; i < params_->BatchBeamSize(); i++) {
diff --git a/src/search.h b/src/search.h
index 9d456368c1..b5d19de908 100644
--- a/src/search.h
+++ b/src/search.h
@@ -44,6 +44,7 @@ struct Search_Cpu : Search {
DeviceSpan GetSequenceLengths() override { return sequence_lengths_; }
bool IsDone() const override { return done_; }
+ void ResetDone();
DeviceSpan GetLogits() const override;
void SetLogits(DeviceSpan logits) override;
@@ -75,6 +76,7 @@ struct GreedySearch_Cpu : Search_Cpu {
void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override;
// Used by continuous decoding search.
+ void ResetDone();
void AppendTokens(DeviceSpan& next_tokens) override;
void RewindTo(size_t index) override;
diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp
index 08af900b56..330ff1b43c 100644
--- a/test/c_api_tests.cpp
+++ b/test/c_api_tests.cpp
@@ -338,11 +338,8 @@ TEST(CAPITests, EndToEndPhiBatch) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequences);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Decode The Batch
@@ -519,11 +516,8 @@ TEST(CAPITests, EndToEndPhi) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequence);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Decode The Batch
@@ -561,11 +555,12 @@ TEST(CAPITests, EndToEndPhiEOSPAD) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequence);
- while (true) {
+ ASSERT_EQ(static_cast(params->GetSearchNumber("max_length")), 40);
+ ASSERT_EQ(params->GetSearchBool("early_stopping"), true);
+ ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
+
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Decode The Batch
@@ -582,6 +577,7 @@ TEST(CAPITests, EndToEndPhiEOSPAD) {
const auto* sequence_data = generator->GetSequenceData(0);
ASSERT_LE(sequence_length, 40);
+ ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
const auto* expected_output_start = &expected_output[0];
EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t)));
@@ -658,11 +654,8 @@ TEST(CAPITests, LoadModelFromMemory) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequence);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Decode The Batch
@@ -738,11 +731,8 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids.data(), input_ids.size());
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -903,11 +893,8 @@ TEST(CAPITests, SetTerminate) {
auto GenerateOutput = [](OgaGenerator* generator, std::unique_ptr tokenizer_stream) {
EXPECT_THROW({
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
} }, std::runtime_error);
};
@@ -968,11 +955,8 @@ struct Phi2Test {
auto generator = OgaGenerator::Create(*model_, *params_);
generator->AppendTokenSequences(*input_sequences_);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Decode One at a time
@@ -1130,11 +1114,8 @@ TEST(CAPITests, AdaptersTest) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequences);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto logits = generator->GetOutput("logits");
@@ -1156,11 +1137,8 @@ TEST(CAPITests, AdaptersTest) {
generator->SetActiveAdapter(*adapters, "adapters_a_and_b");
generator->AppendTokenSequences(*input_sequences);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto logits = generator->GetOutput("logits");
@@ -1210,11 +1188,8 @@ TEST(CAPITests, AdaptersTestMultipleAdapters) {
generator->SetActiveAdapter(*adapters, "adapter_b");
generator->AppendTokenSequences(*input_sequences);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
}
@@ -1256,11 +1231,8 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids.data(), input_ids.size());
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -1278,11 +1250,8 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) {
generator->RewindTo(0);
generator->AppendTokens(input_ids.data(), input_ids.size());
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -1316,11 +1285,8 @@ TEST(CAPITests, RewindGptFp32CAPI) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids.data(), input_ids.size());
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -1334,11 +1300,8 @@ TEST(CAPITests, RewindGptFp32CAPI) {
// Rewind to length 5 and verify same output
generator->RewindTo(5);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -1353,11 +1316,8 @@ TEST(CAPITests, RewindGptFp32CAPI) {
std::vector next_ids{731, 731};
generator->AppendTokens(next_ids.data(), next_ids.size());
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -1386,11 +1346,8 @@ TEST(CAPITests, SetGuidance) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*input_sequences);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto out_string = tokenizer->Decode(generator->GetSequenceData(0), generator->GetSequenceCount(0));
auto output = std::string(out_string).substr(std::string(input_string).size());
diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
index 3b1c72d395..4e6a05b5cd 100644
--- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs
+++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -158,16 +158,16 @@ public void TestGreedySearch()
using (var generator = new Generator(model, generatorParams))
{
Assert.NotNull(generator);
-
generator.AppendTokens(inputIDs);
+
Assert.False(generator.IsDone());
- while (true)
+ Assert.Equal(generatorParams.GetSearchNumber("max_length"), maxLength);
+ Assert.Equal(generatorParams.GetSearchBool("early_stopping"), true);
+ Assert.Equal((int)generator.TokenCount(), generator.GetSequence(0).Length);
+
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -175,6 +175,7 @@ public void TestGreedySearch()
var sequence = generator.GetSequence(i).ToArray();
var expectedSequence = expectedOutput.Skip((int)i * (int)maxLength).Take((int)maxLength);
Assert.Equal(expectedSequence, sequence);
+ Assert.Equal((int)generator.TokenCount(), generator.GetSequence(i).Length);
}
}
}
@@ -216,13 +217,9 @@ public void TestLoadModelFromMemory()
generator.AppendTokens(inputIDs);
Assert.False(generator.IsDone());
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -296,13 +293,9 @@ public void TestTopKSearch()
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -359,13 +352,9 @@ public void TestTopPSearch()
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -424,13 +413,9 @@ public void TestTopKTopPSearch()
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -635,13 +620,9 @@ public void TestPhi2()
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
for (ulong i = 0; i < batchSize; i++)
@@ -745,13 +726,9 @@ public void TestAdapters()
using var generator = new Generator(model, genParams);
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
using var logits = generator.GetOutput("logits");
@@ -777,13 +754,9 @@ public void TestAdapters()
using var generator = new Generator(model, genParams);
generator.SetActiveAdapter(adapters, "adapters_a_and_b");
generator.AppendTokenSequences(sequences);
- while (true)
+ while (!generator.IsDone())
{
generator.GenerateNextToken();
- if (generator.IsDone())
- {
- break;
- }
}
using var logits = generator.GetOutput("logits");
if (_useCudaModel)
diff --git a/test/model_tests.cpp b/test/model_tests.cpp
index 70dd8c2ea4..25bb1b8b96 100644
--- a/test/model_tests.cpp
+++ b/test/model_tests.cpp
@@ -84,11 +84,8 @@ TEST(ModelTests, GreedySearchGptFp32) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -128,11 +125,8 @@ TEST(ModelTests, BeamSearchGptFp32) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -165,11 +159,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label)
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -190,11 +181,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label)
generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -205,11 +193,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label)
generator->RewindTo(3);
std::vector next_ids{731, 731};
generator->AppendTokens(next_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -251,11 +236,8 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -306,11 +288,8 @@ void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* mo
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
// Verify outputs match expected outputs
@@ -362,11 +341,8 @@ void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const cha
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(input_ids);
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto sequence = generator->GetSequence(0);
@@ -404,11 +380,8 @@ Print all primes between 1 and n
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(tokens->Get(0));
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto result = generator->GetSequence(0);
@@ -438,11 +411,8 @@ Print all primes between 1 and n
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokens(tokens->Get(0));
- while (true) {
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
auto result = generator->GetSequence(0);
diff --git a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
index 9ff3bee78c..25ed5e8172 100644
--- a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
+++ b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm
@@ -54,16 +54,18 @@ - (void)testCppAPI_Basic {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*sequences);
- while (true) {
+ XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100);
+ XCTAssertEqual(params->GetSearchBool("early_stopping"), true);
+ XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
+
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length);
+ XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
}
@end
diff --git a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
index a4e5fdfc63..538289c322 100644
--- a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
+++ b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm
@@ -54,16 +54,18 @@ - (void)testCppAPI_Basic {
auto generator = OgaGenerator::Create(*model, *params);
generator->AppendTokenSequences(*sequences);
- while (true) {
+ XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100);
+ XCTAssertEqual(params->GetSearchBool("early_stopping"), true);
+ XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
+
+ while (!generator->IsDone()) {
generator->GenerateNextToken();
- if (generator->IsDone()) {
- break;
- }
}
const auto output_sequence_length = generator->GetSequenceCount(0);
const auto* output_sequence_data = generator->GetSequenceData(0);
auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length);
+ XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0)));
}
@end
diff --git a/test/python/requirements.txt b/test/python/requirements.txt
index 3a851be2e5..2d7c639bdf 100644
--- a/test/python/requirements.txt
+++ b/test/python/requirements.txt
@@ -7,5 +7,5 @@ sympy
pytest
onnx
onnx_ir>=0.1.3
-transformers
+transformers<5.0.0
huggingface_hub[cli]
diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py
index cd6dd43c2d..837bd4726a 100644
--- a/test/python/test_onnxruntime_genai_api.py
+++ b/test/python/test_onnxruntime_genai_api.py
@@ -148,15 +148,18 @@ def test_greedy_search(test_data_path, relative_model_path):
generator = og.Generator(model, search_params)
generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32))
- while True:
+
+ assert int(search_params.get_search_options()["max_length"]) == 10
+ assert search_params.get_search_options()["early_stopping"] == True
+ assert int(generator.token_count()) == 4
+
+ while not generator.is_done():
# Test getting/setting logits
logits = generator.get_logits()
generator.set_logits(logits)
generator.set_logits(logits) # twice just to be sure buffer is still valid
generator.generate_next_token()
- if generator.is_done():
- break
expected_sequence = np.array(
[
@@ -167,7 +170,7 @@ def test_greedy_search(test_data_path, relative_model_path):
)
for i in range(batch_size):
assert np.array_equal(expected_sequence[i], generator.get_sequence(i))
-
+ assert int(generator.token_count()) == len(generator.get_sequence(0))
@pytest.mark.parametrize(
"relative_model_path",
@@ -193,20 +196,16 @@ def test_rewind_cuda(test_data_path, relative_model_path):
generator = og.Generator(model, search_params)
generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
assert generator.get_sequence(0) is not None
generator.rewind_to(3)
generator.append_tokens(np.array([[731, 731]], dtype=np.int32))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
assert generator.get_sequence(0) is not None
@@ -217,10 +216,8 @@ def test_rewind_cuda(test_data_path, relative_model_path):
generator = og.Generator(model, search_params)
generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731], [64, 65, 66, 67]], dtype=np.int32))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(batch_size):
assert generator.get_sequence(i) is not None
@@ -233,10 +230,8 @@ def test_rewind_cuda(test_data_path, relative_model_path):
dtype=np.int32,
)
)
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(batch_size):
assert generator.get_sequence(i) is not None
@@ -262,20 +257,16 @@ def test_rewind(test_data_path, relative_model_path):
generator = og.Generator(model, search_params)
generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
assert np.array_equal(expected_sequence, generator.get_sequence(0))
generator.rewind_to(3)
generator.append_tokens(np.array([[731, 731]], dtype=np.int32))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
assert np.array_equal(expected_sequence, generator.get_sequence(0))
@@ -403,10 +394,8 @@ def test_batching(device, phi2_for):
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(len(prompts)):
print(tokenizer.decode(generator.get_sequence(0)))
@@ -434,10 +423,8 @@ def test_e2e(device, phi2_for):
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(len(prompts)):
print(tokenizer.decode(generator.get_sequence(0)))
@@ -469,10 +456,8 @@ def test_load_model_from_memory(device, wrapper_bytes_function, phi2_for):
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(len(prompts)):
print(tokenizer.decode(generator.get_sequence(0)))
@@ -648,10 +633,8 @@ def _split(onnx_model_path: os.PathLike, output_dir: os.PathLike):
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
expected_output = [
"This is a test.\n # TOD import * doct proofingrad",
@@ -848,10 +831,8 @@ def _export_adapter(adapter, adapter_file_name):
generator.set_active_adapter(adapters, f"adapter_{i}")
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
@pytest.mark.parametrize("device", devices)
@@ -935,10 +916,8 @@ def _prepare_model(test_data_path):
else:
generator.append_tokens(tokenizer.encode_batch(prompts))
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
@pytest.mark.parametrize("relative_model_path", [Path("audio-preprocessing")])
diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py
index 643882ae5c..f28d674e21 100644
--- a/test/python/test_onnxruntime_genai_e2e.py
+++ b/test/python/test_onnxruntime_genai_e2e.py
@@ -31,10 +31,8 @@ def run_model(model_path: str | bytes | os.PathLike):
generator = og.Generator(model, params)
generator.append_tokens(sequences)
- while True:
+ while not generator.is_done():
generator.generate_next_token()
- if generator.is_done():
- break
for i in range(3):
assert generator.get_sequence(i) is not None