diff --git a/README.md b/README.md index 2fd0d344f5..29313d8feb 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena | Support matrix | Supported now | Under development | On the roadmap| | -------------- | ------------- | ----------------- | -------------- | -| Model architectures | ChatGLM
DeepSeek
Ernie
Fara
Gemma
GPTOSS
Granite
Llama
Mistral
Nemotron
OLMo
Phi
Phi3V
Phi4MM
Qwen
Qwen-2.5VL
SmolLM3
Whisper
| Stable diffusion || +| Model architectures | AMD OLMo
ChatGLM
DeepSeek
ERNIE 4.5
Fara
Gemma
gpt-oss
Granite
Llama
Mistral
Nemotron
Phi (language + vision)
Qwen (language + vision)
SmolLM3
Whisper | Stable diffusion | Multi-modal models | | API| Python
C#
C/C++
Java ^ | Objective-C || | O/S | Linux
Windows
Mac
Android || iOS ||| | Architecture | x86
x64
arm64 |||| @@ -78,10 +78,8 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install) try: generator.append_tokens(input_tokens) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end='', flush=True) except KeyboardInterrupt: @@ -115,13 +113,13 @@ Windows: pip list | findstr "onnxruntime-genai" ``` -Checkout the version of the examples that correspond to that release. +Then, check out the version of the examples that corresponds to that release. ```bash # Clone the repo git clone https://github.com/microsoft/onnxruntime-genai.git && cd onnxruntime-genai # Checkout the branch for the version you are using -git checkout v0.11.4 +git checkout v0.11.5 cd examples ``` @@ -145,30 +143,11 @@ Navigate to the examples folder in the main branch. cd examples ``` -## Breaking API changes +To install the nightly Python build: -### v0.11.0 - -Between `v0.11.0` and `v0.10.1`, there is a breaking API usage change to improve model quality during multi-turn conversations. - -Previously, the decoding loop could be written as follows. - -``` -while not IsDone(): - GenerateToken() - GetLastToken() - PrintLastToken() -``` - -In 0.11.0, the decoding loop should now be written as follows. - -``` -while True: - GenerateToken() - if IsDone(): - break - GetLastToken() - PrintLastToken() +```bash +# Change onnxruntime-genai to the Python package you want to install +pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-genai ``` ## Roadmap diff --git a/examples/c/src/model_chat.cpp b/examples/c/src/model_chat.cpp index d8d6b7a882..4efe3a79bf 100644 --- a/examples/c/src/model_chat.cpp +++ b/examples/c/src/model_chat.cpp @@ -91,7 +91,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { const auto current_token_count = generator->GetSequenceCount(0); try { - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); if (is_first_token) { @@ -99,12 +99,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->IsDone()) { - break; - } - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } } catch (const std::exception& e) { diff --git a/examples/c/src/model_qa.cpp b/examples/c/src/model_qa.cpp index 9c06659057..b6b4a3c229 100644 --- a/examples/c/src/model_qa.cpp +++ b/examples/c/src/model_qa.cpp @@ -82,7 +82,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { generator->AppendTokenSequences(*sequences); try { - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); if (is_first_token) { @@ -90,12 +90,7 @@ void CXX_API(const char* model_path, const char* execution_provider) { is_first_token = false; } - if (generator->IsDone()) { - break; - } - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } } catch (const std::exception& e) { diff --git a/examples/c/src/model_vision.cpp b/examples/c/src/model_vision.cpp index 1be5de0470..b902ecee2c 100644 --- a/examples/c/src/model_vision.cpp +++ b/examples/c/src/model_vision.cpp @@ -90,15 +90,9 @@ void CXX_API(const char* model_path, const char* execution_provider) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*input_tensors); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - - if (generator->IsDone()) { - break; - } - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << stream->Decode(new_token) << std::flush; } diff --git a/examples/c/src/phi4-mm.cpp b/examples/c/src/phi4-mm.cpp index ab0be4d77c..e1deb781d4 100644 --- a/examples/c/src/phi4-mm.cpp +++ b/examples/c/src/phi4-mm.cpp @@ -101,15 +101,9 @@ void CXX_API(const char* model_path, const char* execution_provider) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*input_tensors); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - - if (generator->IsDone()) { - break; - } - - const auto num_tokens = generator->GetSequenceCount(0); - const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; + const auto new_token = generator->GetNextTokens()[0]; std::cout << stream->Decode(new_token) << std::flush; } diff --git a/examples/c/src/whisper.cpp b/examples/c/src/whisper.cpp index 48a657972c..32e8c9e029 100644 --- a/examples/c/src/whisper.cpp +++ b/examples/c/src/whisper.cpp @@ -57,11 +57,8 @@ void CXX_API(const char* model_path, int32_t num_beams) { auto generator = OgaGenerator::Create(*model, *params); generator->SetInputs(*inputs); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } for (size_t i = 0; i < static_cast(num_beams * batch_size); ++i) { diff --git a/examples/csharp/HelloPhi/Program.cs b/examples/csharp/HelloPhi/Program.cs index c23ab9907c..e2f64dfc8f 100644 --- a/examples/csharp/HelloPhi/Program.cs +++ b/examples/csharp/HelloPhi/Program.cs @@ -127,13 +127,9 @@ static string GetPrompt(bool interactive) using var generator = new Generator(model, generatorParams); generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } var outputSequence = generator.GetSequence(0); @@ -155,13 +151,9 @@ static string GetPrompt(bool interactive) using var generator = new Generator(model, generatorParams); generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); @@ -196,13 +188,9 @@ static string GetPrompt(bool interactive) var sequences = tokenizer.Encode(tokenizer.ApplyChatTemplate("", messages, "", true)); var watch = System.Diagnostics.Stopwatch.StartNew(); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0])); } Console.WriteLine(); diff --git a/examples/csharp/HelloPhi3V/Program.cs b/examples/csharp/HelloPhi3V/Program.cs index 09e1038bd5..1d32cac199 100644 --- a/examples/csharp/HelloPhi3V/Program.cs +++ b/examples/csharp/HelloPhi3V/Program.cs @@ -168,13 +168,9 @@ void PrintUsage() using var generator = new Generator(model, generatorParams); generator.SetInputs(inputTensors); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); diff --git a/examples/csharp/HelloPhi4MM/Program.cs b/examples/csharp/HelloPhi4MM/Program.cs index bcec8d714f..ce0ddf359d 100644 --- a/examples/csharp/HelloPhi4MM/Program.cs +++ b/examples/csharp/HelloPhi4MM/Program.cs @@ -215,13 +215,9 @@ void PrintUsage() using var generator = new Generator(model, generatorParams); generator.SetInputs(inputTensors); var watch = System.Diagnostics.Stopwatch.StartNew(); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } Console.Write(stream.Decode(generator.GetNextTokens()[0])); } watch.Stop(); diff --git a/examples/python/awq-quantized-model.py b/examples/python/awq-quantized-model.py index fd9991b3c7..465935e264 100644 --- a/examples/python/awq-quantized-model.py +++ b/examples/python/awq-quantized-model.py @@ -108,11 +108,8 @@ def run_model(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) except KeyboardInterrupt: diff --git a/examples/python/model-chat.py b/examples/python/model-chat.py index 44c9add4ff..ca9746b16a 100644 --- a/examples/python/model-chat.py +++ b/examples/python/model-chat.py @@ -216,16 +216,13 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/model-generate.py b/examples/python/model-generate.py index 649d393c4a..86be56922b 100644 --- a/examples/python/model-generate.py +++ b/examples/python/model-generate.py @@ -99,10 +99,8 @@ def main(args): if args.verbose: print("Generating tokens ...\n") start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break run_time = time.time() - start_time for i in range(len(prompts)): diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index cfd31cd946..a9b2cf7802 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -207,16 +207,13 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/model-vision.py b/examples/python/model-vision.py index acd95f1ff1..e226697e15 100644 --- a/examples/python/model-vision.py +++ b/examples/python/model-vision.py @@ -150,11 +150,8 @@ def run(args: argparse.Namespace): generator.set_inputs(inputs) start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(stream.decode(new_token), end="", flush=True) diff --git a/examples/python/phi3-qa.py b/examples/python/phi3-qa.py index e0296d1905..645954fee3 100644 --- a/examples/python/phi3-qa.py +++ b/examples/python/phi3-qa.py @@ -75,16 +75,13 @@ def main(args): print("Output: ", end="", flush=True) try: - while True: + while not generator.is_done(): generator.generate_next_token() if args.timings: if first: first_token_timestamp = time.time() first = False - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) if args.timings: diff --git a/examples/python/phi4-mm.py b/examples/python/phi4-mm.py index 80a1413942..c9aa155cd4 100644 --- a/examples/python/phi4-mm.py +++ b/examples/python/phi4-mm.py @@ -145,11 +145,8 @@ def run(args: argparse.Namespace): generator.set_inputs(inputs) start_time = time.time() - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break - new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end="", flush=True) diff --git a/src/config.h b/src/config.h index a9a8073745..51358b8192 100644 --- a/src/config.h +++ b/src/config.h @@ -301,7 +301,7 @@ struct Config { int top_k{50}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model. float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. float temperature{1.0f}; // Temperature to control during generation. Default is 1.0. - bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not. + bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not. int no_repeat_ngram_size{}; // Unused param float diversity_penalty{}; // Unused param float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. diff --git a/src/csharp/Generator.cs b/src/csharp/Generator.cs index 0c7eb31d81..94f83f0798 100644 --- a/src/csharp/Generator.cs +++ b/src/csharp/Generator.cs @@ -46,6 +46,15 @@ public void AppendTokenSequences(Sequences sequences) Result.VerifySuccess(NativeMethods.OgaGenerator_AppendTokenSequences(_generatorHandle, sequences.Handle)); } + /// + /// Gets the number of tokens in the generator + /// + /// The token count + public ulong TokenCount() + { + return NativeMethods.OgaGenerator_TokenCount(_generatorHandle).ToUInt64(); + } + public void GenerateNextToken() { Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle)); diff --git a/src/csharp/GeneratorParams.cs b/src/csharp/GeneratorParams.cs index cc23eba0f9..4eb7e01fa7 100644 --- a/src/csharp/GeneratorParams.cs +++ b/src/csharp/GeneratorParams.cs @@ -29,14 +29,21 @@ public void SetSearchOption(string searchOption, bool value) Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value)); } - public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize) + public void SetGuidance(string type, string data, bool enableFFTokens = false) { - Console.WriteLine("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release."); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens)); } - public void SetGuidance(string type, string data, bool enableFFTokens = false) + public double GetSearchNumber(string searchOption) { - Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens)); + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out double value)); + return value; + } + + public bool GetSearchBool(string searchOption) + { + Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out bool value)); + return value; } ~GeneratorParams() diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index 09abc65bac..1ba9aac906 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -109,18 +109,27 @@ internal class NativeLib public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams, byte[] /* const char* */ searchOption, bool value); - - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model, - IntPtr /* const OgaGeneratorParams* */ generatorParams, - out IntPtr /* OgaGenerator** */ generator); - [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetGuidance(IntPtr /* OgaGeneratorParams* */ generatorParams, byte[] /* const char* */ type, byte[] /* const char* */ data, bool /* boolean */ enable_ff_tokens); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchNumber(IntPtr /* OgaGeneratorParams* */ generatorParams, + byte[] /* const char* */ searchOption, + out double /* const double* */ value); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsGetSearchBool(IntPtr /* OgaGeneratorParams* */ generatorParams, + byte[] /* const char* */ searchOption, + out bool /* const bool* */ value); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model, + IntPtr /* const OgaGeneratorParams* */ generatorParams, + out IntPtr /* OgaGenerator** */ generator); + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern void OgaDestroyGenerator(IntPtr /* OgaGenerator* */ generator); @@ -157,6 +166,10 @@ internal class NativeLib public static extern IntPtr /* OgaResult* */ OgaGenerator_AppendTokenSequences(IntPtr /* OgaGenerator* */ generator, IntPtr /* const OgaSequences* */ sequences); + // This function is used to get the number of tokens in the generator. + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern UIntPtr OgaGenerator_TokenCount(IntPtr /* const OgaGenerator* */ generator); + // This function is used to rewind the generator to the given newLength. [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] diff --git a/src/cuda/search_cuda.cpp b/src/cuda/search_cuda.cpp index ad5a7c6c5a..b306d0f473 100644 --- a/src/cuda/search_cuda.cpp +++ b/src/cuda/search_cuda.cpp @@ -25,14 +25,13 @@ Search_Cuda::Search_Cuda(const GeneratorParams& params) sequence_lengths_ = params.p_device->Allocate(batch_beam_size); eos_seen_buffer_ = CudaMallocArray(batch_beam_size, &eos_seen_); - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); + done_cpu_ = CudaMallocHostArray(1); eos_token_ids_ = params.p_device->Allocate(params.config.model.eos_token_id.size()); copy(std::span{params.config.model.eos_token_id}, eos_token_ids_.CpuSpan()); eos_token_ids_.CopyCpuToDevice(); - done_cpu_ = CudaMallocHostArray(1); - *done_cpu_ = false; + ResetDone(); } GreedySearch_Cuda::GreedySearch_Cuda(const GeneratorParams& params) @@ -75,6 +74,11 @@ BeamSearch_Cuda::BeamSearch_Cuda(const GeneratorParams& params) BeamSearch_Cuda::~BeamSearch_Cuda() = default; +void Search_Cuda::ResetDone() { + *done_cpu_ = false; + cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); +} + DeviceSpan Search_Cuda::GetLogits() const { return next_token_scores_; } @@ -221,8 +225,7 @@ std::span Search_Cuda::GetScores() { // Set user input tokens (batch_beam_size, sequence_length) void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); auto next_tokens_gpu = next_tokens.Span(); cuda::Launch_AppendNextTokensToSequences(next_tokens_gpu, sequences_.GetSequences().Span(), params_->BatchBeamSize(), sequences_.GetSequenceLength(), sequences_.max_length_, GetStream()); @@ -235,8 +238,7 @@ void GreedySearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { return; } - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); } void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { @@ -248,8 +250,7 @@ void BeamSearch_Cuda::AppendTokens(DeviceSpan& next_tokens) { } void GreedySearch_Cuda::RewindTo(size_t index) { - cudaMemsetAsync(eos_seen_.data(), 0, eos_seen_.size_bytes(), GetStream()); - *done_cpu_ = false; + ResetDone(); if (index > 0) cuda::Launch_GetLastTokens(next_tokens_.data(), sequences_.GetSequences().Span().data(), static_cast(params_->BatchBeamSize()), static_cast(index), sequences_.max_length_, GetStream()); else diff --git a/src/cuda/search_cuda.h b/src/cuda/search_cuda.h index 6d0f03f73d..b5ead4c28c 100644 --- a/src/cuda/search_cuda.h +++ b/src/cuda/search_cuda.h @@ -19,6 +19,7 @@ struct Search_Cuda : Search { cudaStreamSynchronize(GetStream()); return *done_cpu_; } // TODO: Use an event + void ResetDone(); DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; diff --git a/src/generators.cpp b/src/generators.cpp index d163eeacde..0473c51a07 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -288,6 +288,52 @@ bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_t (search.num_beams == 1 || model_type == "whisper"); } +double GeneratorParams::GetSearchNumber(std::string_view name) const { + if (name == "batch_size") { + return static_cast(search.batch_size); + } else if (name == "chunk_size") { + return static_cast(search.chunk_size.value_or(0)); + } else if (name == "diversity_penalty") { + return search.diversity_penalty; + } else if (name == "length_penalty") { + return search.length_penalty; + } else if (name == "max_length") { + return static_cast(search.max_length); + } else if (name == "min_length") { + return static_cast(search.min_length); + } else if (name == "no_repeat_ngram_size") { + return static_cast(search.no_repeat_ngram_size); + } else if (name == "num_beams") { + return static_cast(search.num_beams); + } else if (name == "num_return_sequences") { + return static_cast(search.num_return_sequences); + } else if (name == "random_seed") { + return static_cast(search.random_seed); + } else if (name == "repetition_penalty") { + return search.repetition_penalty; + } else if (name == "temperature") { + return search.temperature; + } else if (name == "top_k") { + return static_cast(search.top_k); + } else if (name == "top_p") { + return search.top_p; + } else { + throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchNumber."); + } +} + +bool GeneratorParams::GetSearchBool(std::string_view name) const { + if (name == "do_sample") { + return search.do_sample; + } else if (name == "early_stopping") { + return search.early_stopping; + } else if (name == "past_present_share_buffer") { + return search.past_present_share_buffer; + } else { + throw std::runtime_error(std::string(name) + " is an invalid name for GetSearchBool."); + } +} + std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params) { return std::make_unique(model, params); } @@ -454,6 +500,10 @@ void Generator::SetRuntimeOption(const char* key, const char* value) { state_->SetRunOption(key, value); } +size_t Generator::TokenCount() const { + return static_cast(search_->GetSequenceLength()); +} + bool Generator::IsDone() { ThrowErrorIfSessionTerminated(state_->session_terminated_); if (computed_logits_) { diff --git a/src/generators.h b/src/generators.h index 1d643e198b..ccfed8db44 100644 --- a/src/generators.h +++ b/src/generators.h @@ -74,6 +74,10 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec const Config& config; // The model outlives the GeneratorParams Config::Search search{config.search}; // Copy of the search parameters from the config + // Query the params to get the value set for a param + double GetSearchNumber(std::string_view name) const; + bool GetSearchBool(std::string_view name) const; + int max_batch_size{0}; bool use_graph_capture{}; bool use_multi_profile{}; @@ -95,6 +99,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone(); + size_t TokenCount() const; void AppendTokens(cpu_span input_ids); void GenerateNextToken(); void RewindToLength(size_t new_length); // Rewind state to new_length diff --git a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java index 66489ed8df..13869c25ca 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/Generator.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/Generator.java @@ -130,6 +130,19 @@ public void appendTokenSequences(Sequences sequences) throws GenAIException { appendTokenSequences(nativeHandle, sequences.nativeHandle()); } + /** + * Returns the token count in the generator. + * + * @throws GenAIException If the call to the GenAI native API fails. + */ + public long tokenCount() throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + return tokenCount(nativeHandle); + } + /** * Rewinds the generator to the given length. This is useful when the user wants to rewind the * generator to a specific length and continue generating from that point. @@ -280,6 +293,8 @@ private native void setModelInput(long nativeHandle, String inputName, long tens private native void appendTokenSequences(long nativeHandle, long sequencesHandle) throws GenAIException; + private native long tokenCount(long nativeHandle) throws GenAIException; + private native void rewindTo(long nativeHandle, long newLength) throws GenAIException; private native void generateNextTokenNative(long nativeHandle) throws GenAIException; diff --git a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java index 7bf8306f4a..c6fd3f4945 100644 --- a/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java +++ b/src/java/src/main/java/ai/onnxruntime/genai/GeneratorParams.java @@ -29,7 +29,7 @@ public GeneratorParams(Model model) throws GenAIException { } /** - * Set seach option with double value. + * Set search option with double value. * * @param optionName The option name. * @param value The option value. @@ -58,6 +58,36 @@ public void setSearchOption(String optionName, boolean value) throws GenAIExcept setSearchOptionBool(nativeHandle, optionName, value); } + /** + * Get search option with numerical value. + * + * @param optionName The option name. + * @return The option value. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public double getSearchNumber(String optionName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + return getSearchNumber(nativeHandle, optionName); + } + + /** + * Get search option with boolean value. + * + * @param optionName The option name. + * @return The option value. + * @throws GenAIException If the call to the GenAI native API fails. + */ + public boolean getSearchBool(String optionName) throws GenAIException { + if (nativeHandle == 0) { + throw new IllegalStateException("Instance has been freed and is invalid"); + } + + return getSearchBool(nativeHandle, optionName); + } + @Override public void close() { if (nativeHandle != 0) { @@ -87,4 +117,8 @@ private native void setSearchOptionNumber(long nativeHandle, String optionName, private native void setSearchOptionBool(long nativeHandle, String optionName, boolean value) throws GenAIException; + + private native double getSearchNumber(long nativeHandle, String optionName) throws GenAIException; + + private native boolean getSearchBool(long nativeHandle, String optionName) throws GenAIException; } diff --git a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp index c75e397917..5327683ecb 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp @@ -67,6 +67,13 @@ Java_ai_onnxruntime_genai_Generator_appendTokens(JNIEnv* env, jobject thiz, jlon env->ReleaseIntArrayElements(token_ids, tokens, JNI_ABORT); } +JNIEXPORT jlong JNICALL +Java_ai_onnxruntime_genai_Generator_tokenCount(JNIEnv* env, jobject thiz, jlong native_handle) { + OgaGenerator* generator = reinterpret_cast(native_handle); + size_t count = OgaGenerator_TokenCount(generator); + return static_cast(count); +} + JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_genai_Generator_isDone(JNIEnv* env, jobject thiz, jlong native_handle) { return OgaGenerator_IsDone(reinterpret_cast(native_handle)); diff --git a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp index 3664c0e3e8..907ab93040 100644 --- a/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp +++ b/src/java/src/main/native/ai_onnxruntime_genai_GeneratorParams.cpp @@ -43,3 +43,23 @@ Java_ai_onnxruntime_genai_GeneratorParams_setSearchOptionBool(JNIEnv* env, jobje ThrowIfError(env, OgaGeneratorParamsSetSearchBool(generator_params, name, value)); } + +JNIEXPORT jdouble JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_getSearchNumber(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { + OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); + CString name{env, option_name}; + double value; + + ThrowIfError(env, OgaGeneratorParamsGetSearchNumber(generator_params, name, &value)); + return static_cast(value); +} + +JNIEXPORT jboolean JNICALL +Java_ai_onnxruntime_genai_GeneratorParams_getSearchBool(JNIEnv* env, jobject thiz, jlong native_handle, jstring option_name) { + OgaGeneratorParams* generator_params = reinterpret_cast(native_handle); + CString name{env, option_name}; + bool value; + + ThrowIfError(env, OgaGeneratorParamsGetSearchBool(generator_params, name, &value)); + return static_cast(value); +} diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java index beb0a12ad9..bc25395a1f 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GenerationTest.java @@ -138,6 +138,11 @@ public void testWithInputIds() throws GenAIException { try (Generator generator = new Generator(model, params); ) { generator.appendTokens(inputIDs); + + assertEquals(params.getSearchNumber("max_length"), maxLength); + assertEquals(params.getSearchBool("early_stopping"), true); + assertEquals(generator.tokenCount(), 4); + while (!generator.isDone()) { generator.generateNextToken(); } @@ -148,6 +153,7 @@ public void testWithInputIds() throws GenAIException { assertEquals(outputIds[j], expectedOutput[i * maxLength + j]); } } + assertEquals(generator.tokenCount(), generator.getSequence(0).length); } } } diff --git a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java index cd8698de38..944d3c4047 100644 --- a/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java +++ b/src/java/src/test/java/ai/onnxruntime/genai/GeneratorParamsTest.java @@ -12,7 +12,7 @@ public class GeneratorParamsTest { @Test public void testValidSearchOption() throws GenAIException { - // test setting an invalid search option throws a GenAIException + // test setting a valid search option try (SimpleGenAI generator = new SimpleGenAI(TestUtils.tinyGpt2ModelPath()); GeneratorParams params = generator.createGeneratorParams(); ) { params.setSearchOption("early_stopping", true); // boolean diff --git a/src/objectivec/error_utils.h b/src/objectivec/error_utils.h index 8c73faa971..43672ee465 100644 --- a/src/objectivec/error_utils.h +++ b/src/objectivec/error_utils.h @@ -23,7 +23,9 @@ void OGASaveExceptionToError(const std::exception& e, NSError** error); } #define OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) OGA_OBJC_API_IMPL_CATCH(error, NO) - +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE (error) OGA_OBJC_API_IMPL_CATCH(error, double(0.0)) +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T (error) OGA_OBJC_API_IMPL_CATCH(error, int32_t(0)) #define OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) OGA_OBJC_API_IMPL_CATCH(error, nil) +#define OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T (error) OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) NS_ASSUME_NONNULL_END diff --git a/src/objectivec/include/ort_genai_objc.h b/src/objectivec/include/ort_genai_objc.h index 2e3d07c836..c3f23655c7 100644 --- a/src/objectivec/include/ort_genai_objc.h +++ b/src/objectivec/include/ort_genai_objc.h @@ -146,6 +146,27 @@ typedef NS_ENUM(NSInteger, OGAElementType) { - (nullable instancetype)initWithModel:(OGAModel*)model error:(NSError**)error NS_DESIGNATED_INITIALIZER; +/** + * Return the int representation of the BOS token. + * + * @return The BOS token id + */ +- (int32_t)getBosTokenId:(NSError**)error; + +/** + * Return the int representations of the array of EOS tokens. + * + * @return The array of EOS token ids + */ +- (nullable NSArray*)getEosTokenIds:(NSError**)error; + +/** + * Return the int representation of the PAD token. + * + * @return The PAD token id + */ +- (int32_t)getPadTokenId:(NSError**)error; + /** * Encode text to sequences * @@ -273,6 +294,23 @@ typedef NS_ENUM(NSInteger, OGAElementType) { - (BOOL)setSearchOption:(NSString*)key boolValue:(BOOL)value error:(NSError**)error; + +/** + * Get numerical value of option. + * @param key The option key. + * @param error Optional error information set if an error occurs. + * @return The option value. + */ +- (double)getSearchNumber:(NSString*)key + error:(NSError**)error; +/** + * Get boolean value of option. + * @param key The option key. + * @param error Optional error information set if an error occurs. + * @return The option value. + */ +- (BOOL)getSearchBool:(NSString*)key + error:(NSError**)error; @end /** @@ -331,6 +369,13 @@ typedef NS_ENUM(NSInteger, OGAElementType) { */ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error; +/** + * Get the number of tokens in the generator. + * @param error Optional error information set if an error occurs. + * @return The number of tokens in the generator + */ +- (size_t)tokenCount:(NSError**)error; + /** * Rewinds the generator to the given length. * @param newLength The desired length in tokens after rewinding. diff --git a/src/objectivec/oga_generator.mm b/src/objectivec/oga_generator.mm index b5ca0e941c..6a9b9b3669 100644 --- a/src/objectivec/oga_generator.mm +++ b/src/objectivec/oga_generator.mm @@ -67,6 +67,13 @@ - (BOOL)appendTokens:(NSArray*)tokens error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } +- (size_t)tokenCount:(NSError**)error { + try { + return _generator->TokenCount(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) +} + - (BOOL)rewindTo:(size_t)newLength error:(NSError**)error { try { _generator->RewindTo(newLength); @@ -110,7 +117,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error { try { return _generator->GetSequenceCount(index); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } + (void)shutdown { diff --git a/src/objectivec/oga_generator_params.mm b/src/objectivec/oga_generator_params.mm index 1ae8719d41..25ecfd0386 100644 --- a/src/objectivec/oga_generator_params.mm +++ b/src/objectivec/oga_generator_params.mm @@ -37,6 +37,20 @@ - (BOOL)setSearchOption:(NSString*)key boolValue:(BOOL)value error:(NSError**)er OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) } +- (double)getSearchNumber:(NSString*)key error:(NSError**)error { + try { + return _generatorParams->GetSearchNumber([key UTF8String]); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_DOUBLE(error) +} + +- (BOOL)getSearchBool:(NSString*)key error:(NSError**)error { + try { + return _generatorParams->GetSearchBool([key UTF8String]); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_BOOL(error) +} + - (OgaGeneratorParams&)CXXAPIOgaGeneratorParams { return *(_generatorParams.get()); } diff --git a/src/objectivec/oga_sequences.mm b/src/objectivec/oga_sequences.mm index 7dd717b360..6ff7c74250 100644 --- a/src/objectivec/oga_sequences.mm +++ b/src/objectivec/oga_sequences.mm @@ -30,7 +30,7 @@ - (size_t)getCountWithError:(NSError**)error { try { return _sequences->Count(); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } - (nullable const int32_t*)sequenceDataAtIndex:(size_t)index error:(NSError**)error { @@ -44,7 +44,7 @@ - (size_t)sequenceCountAtIndex:(size_t)index error:(NSError**)error { try { return _sequences->SequenceCount(index); } - OGA_OBJC_API_IMPL_CATCH(error, size_t(-1)) + OGA_OBJC_API_IMPL_CATCH_RETURNING_SIZE_T(error) } - (OgaSequences&)CXXAPIOgaSequences { diff --git a/src/objectivec/oga_tokenizer.mm b/src/objectivec/oga_tokenizer.mm index c961faabc8..a3eb27ffd2 100644 --- a/src/objectivec/oga_tokenizer.mm +++ b/src/objectivec/oga_tokenizer.mm @@ -21,6 +21,35 @@ - (nullable instancetype)initWithModel:(OGAModel*)model error:(NSError**)error { OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) } +- (int32_t)getBosTokenId:(NSError**)error { + try { + return _tokenizer->GetBosTokenId(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error) +} + +- (nullable NSArray*)getEosTokenIds:(NSError**)error { + try { + std::vector eos_ids = _tokenizer->GetEosTokenIds(); + NSMutableArray* result = [NSMutableArray arrayWithCapacity:eos_ids.size()]; + if (!result) { + return nil; + } + for (int32_t eos_id : eos_ids) { + [result addObject:@(eos_id)]; + } + return result; + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error) +} + +- (int32_t)getPadTokenId:(NSError**)error { + try { + return _tokenizer->GetPadTokenId(); + } + OGA_OBJC_API_IMPL_CATCH_RETURNING_INT32_T(error) +} + - (nullable OGASequences*)encode:(NSString*)str error:(NSError**)error { OGASequences* sequences = [[OGASequences alloc] initWithError:error]; if (!sequences) { diff --git a/src/ort_genai.h b/src/ort_genai.h index dc38a1b967..d6ea59d3f1 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -324,12 +324,10 @@ struct OgaTokenizer : OgaAbstract { } #else std::vector GetEosTokenIds() const { - std::vector eos_ids; const int32_t* eos_ids_ptr; size_t count; OgaCheckResult(OgaTokenizerGetEosTokenIds(this, &eos_ids_ptr, &count)); - eos_ids.assign(eos_ids_ptr, eos_ids_ptr + count); - return eos_ids; + return std::vector(eos_ids_ptr, eos_ids_ptr + count); } #endif @@ -426,14 +424,22 @@ struct OgaGeneratorParams : OgaAbstract { OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value)); } - void TryGraphCaptureWithMaxBatchSize(int /*max_batch_size*/) { - printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n"); - } - void SetGuidance(const char* type, const char* data, bool enable_ff_tokens = false) { OgaCheckResult(OgaGeneratorParamsSetGuidance(this, type, data, enable_ff_tokens)); } + double GetSearchNumber(const char* name) const { + double value; + OgaCheckResult(OgaGeneratorParamsGetSearchNumber(this, name, &value)); + return value; + } + + bool GetSearchBool(const char* name) const { + bool value; + OgaCheckResult(OgaGeneratorParamsGetSearchBool(this, name, &value)); + return value; + } + static void operator delete(void* p) { OgaDestroyGeneratorParams(reinterpret_cast(p)); } }; @@ -448,6 +454,10 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } + bool IsSessionTerminated() const { + return OgaGenerator_IsSessionTerminated(this); + } + void SetModelInput(const char* name, OgaTensor& tensor) { OgaCheckResult(OgaGenerator_SetModelInput(this, name, &tensor)); } @@ -470,8 +480,8 @@ struct OgaGenerator : OgaAbstract { } #endif - bool IsSessionTerminated() const { - return OgaGenerator_IsSessionTerminated(this); + size_t TokenCount() const { + return OgaGenerator_TokenCount(this); } void GenerateNextToken() { @@ -485,6 +495,13 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); return {out, out_count}; } +#else + std::vector GetNextTokens() { + const int32_t* out; + size_t out_count; + OgaCheckResult(OgaGenerator_GetNextTokens(this, &out, &out_count)); + return std::vector(out, out + out_count); + } #endif void RewindTo(size_t new_length) { diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index e18b424b0b..41c70cdb3c 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -388,16 +388,23 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* para OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size) { +OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) { OGA_TRY - printf("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release\n"); + params->SetGuidance(type, data, enable_ff_tokens); return nullptr; OGA_CATCH } -OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) { +OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value) { OGA_TRY - params->SetGuidance(type, data, enable_ff_tokens); + *value = params->GetSearchNumber(name); + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value) { + OGA_TRY + *value = params->GetSearchBool(name); return nullptr; OGA_CATCH } @@ -457,6 +464,10 @@ OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const OGA_CATCH } +size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator) { + return generator->TokenCount(); +} + OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) { OGA_TRY generator->GenerateNextToken(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 7b7c767587..de7d4fa484 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -411,9 +411,23 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* mode */ OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* params); +/** + * \brief Set a numerical value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[in] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* params, const char* name, double value); + +/** + * \brief Set a boolean value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[in] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* params, const char* name, bool value); -OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* params, int32_t max_batch_size); /** * \brief Sets the guidance type and data for the Generator params @@ -425,6 +439,24 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatch */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens); +/** + * \brief Get a numerical value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[out] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchNumber(const OgaGeneratorParams* params, const char* name, double* value); + +/** + * \brief Get a boolean value for a search parameter + * \param[in] params The generator params to set. + * \param[in] name The name of the search parameter. + * \param[out] value The value of the search parameter. + * \return OgaResult containing the error message if setting the generator params failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsGetSearchBool(const OgaGeneratorParams* params, const char* name, bool* value); + /** * \brief Creates a generator from the given model and generator params. * \param[in] model The model to use for generation. @@ -446,6 +478,12 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); * \return True if the generator has finished generating all the sequences, false otherwise. */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(OgaGenerator* generator); + +/** + * \brief Returns true if the session has been terminated. + * \param[in] generator The generator to add the inputs to. + * \return True if the session has been terminated, false otherwise. + */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator); /** @@ -481,6 +519,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokenSequences(OgaGenerato */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_AppendTokens(OgaGenerator* generator, const int32_t* input_ids, size_t input_ids_count); +/** + * \brief Returns the number of tokens in the generator + * \param[in] generator The generator containing the appended tokens. + * \return The number of tokens that have been added. + */ +OGA_EXPORT size_t OGA_API_CALL OgaGenerator_TokenCount(const OgaGenerator* generator); + /** * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. * \param[in] generator The generator to compute the logits for. @@ -497,6 +542,13 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GetNextTokens(const OgaGenerator* generator, const int32_t** out, size_t* out_count); +/** + * \brief Set a runtime option's name and value. + * \param[in] generator The generator to rewind to the given length. + * \param[in] key The runtime option's name + * \param[in] value The runtime option's value + * \return OgaResult containing the error message if setting the runtime option failed. + */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value); /** @@ -611,17 +663,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaUpdateTokenizerOptions( size_t num_options); /** - * Return the int representation of the BOS token + * \brief Return the int representation of the BOS token + * \param[in] tokenizer The tokenizer to read from + * \param[out] token_id The BOS token id + * \return OgaResult containing the error message if returning the BOS token id fails. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetBosTokenId(const OgaTokenizer* tokenizer, int32_t* token_id); /** - * Return an array containing the int representations of the EOS tokens. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed. + * \brief Return an array containing the int representations of the EOS token ids. The array is owned by the tokenizer and will be freed when the tokenizer is destroyed. + * \param[in] tokenizer The tokenizer to read from + * \param[out] eos_token_ids The array of EOS token ids + * \param[out] token_count The length of the array + * \return OgaResult containing the error message if returning the EOS token ids fails. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetEosTokenIds(const OgaTokenizer* tokenizer, const int32_t** eos_token_ids, size_t* token_count); /** - * Return the int representation of the BOS token + * \brief Return the int representation of the PAD token + * \param[in] tokenizer The tokenizer to read from + * \param[out] token_id The PAD token id + * \return OgaResult containing the error message if returning the PAD token id fails. */ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerGetPadTokenId(const OgaTokenizer* tokenizer, int32_t* token_id); diff --git a/src/python/python.cpp b/src/python/python.cpp index 547a116b0f..3626006b32 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -192,14 +192,32 @@ struct PyGeneratorParams { } } - void TryGraphCaptureWithMaxBatchSize(pybind11::int_ max_batch_size) { - std::cerr << "TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release" << std::endl; - } - void SetGuidance(const std::string& type, const std::string& data, bool enable_ff_tokens = false) { params_->SetGuidance(type.c_str(), data.c_str(), enable_ff_tokens); } + pybind11::dict GetSearchOptions() { + pybind11::dict d; + d["batch_size"] = params_->GetSearchNumber("batch_size"); + d["chunk_size"] = params_->GetSearchNumber("chunk_size"); + d["diversity_penalty"] = params_->GetSearchNumber("diversity_penalty"); + d["do_sample"] = params_->GetSearchBool("do_sample"); + d["early_stopping"] = params_->GetSearchBool("early_stopping"); + d["length_penalty"] = params_->GetSearchNumber("length_penalty"); + d["max_length"] = params_->GetSearchNumber("max_length"); + d["min_length"] = params_->GetSearchNumber("min_length"); + d["no_repeat_ngram_size"] = params_->GetSearchNumber("no_repeat_ngram_size"); + d["num_beams"] = params_->GetSearchNumber("num_beams"); + d["num_return_sequences"] = params_->GetSearchNumber("num_return_sequences"); + d["past_present_share_buffer"] = params_->GetSearchBool("past_present_share_buffer"); + d["random_seed"] = params_->GetSearchNumber("random_seed"); + d["repetition_penalty"] = params_->GetSearchNumber("repetition_penalty"); + d["temperature"] = params_->GetSearchNumber("temperature"); + d["top_k"] = params_->GetSearchNumber("top_k"); + d["top_p"] = params_->GetSearchNumber("top_p"); + return d; + } + std::vector refs_; // References to data we want to ensure doesn't get garbage collected }; @@ -240,6 +258,10 @@ struct PyGenerator { generator_->AppendTokens(ToSpan(tokens)); } + size_t TokenCount() const { + return generator_->TokenCount(); + } + pybind11::array_t GetLogits() { return ToNumpy(*generator_->GetLogits()); } @@ -316,11 +338,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) { pybind11::class_(m, "GeneratorParams") .def(pybind11::init()) - .def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize) .def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options .def("set_guidance", &PyGeneratorParams::SetGuidance, pybind11::arg("type"), pybind11::arg("data"), - pybind11::arg("enable_ff_tokens") = false); + pybind11::arg("enable_ff_tokens") = false) + .def("get_search_options", &PyGeneratorParams::GetSearchOptions); pybind11::class_(m, "TokenizerStream") .def("decode", [](OgaTokenizerStream& t, int32_t token) { return t.Decode(token); }); @@ -461,6 +483,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("set_model_input", &PyGenerator::SetModelInput) .def("append_tokens", pybind11::overload_cast&>(&PyGenerator::AppendTokens)) .def("append_tokens", pybind11::overload_cast(&PyGenerator::AppendTokens)) + .def("token_count", &PyGenerator::TokenCount) .def("get_logits", &PyGenerator::GetLogits) .def("set_logits", &PyGenerator::SetLogits) .def("generate_next_token", &PyGenerator::GenerateNextToken) diff --git a/src/search.cpp b/src/search.cpp index 1d69c30115..3e9d545c56 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -49,6 +49,17 @@ BeamSearch_Cpu::BeamSearch_Cpu(const GeneratorParams& params) BeamSearch_Cpu::~BeamSearch_Cpu() = default; +void Search_Cpu::ResetDone() { + // Reset done count/state + done_ = false; +} + +void GreedySearch_Cpu::ResetDone() { + Search_Cpu::ResetDone(); + not_done_count_ = params_->search.batch_size; + memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); +} + DeviceSpan Search_Cpu::GetLogits() const { return next_token_scores_; } @@ -333,16 +344,13 @@ void GreedySearch_Cpu::AppendTokens(DeviceSpan& next_tokens) { } AppendNextTokensToSequences(); } - // Reset done count/state - done_ = false; - not_done_count_ = params_->search.batch_size; - memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + + ResetDone(); } void GreedySearch_Cpu::RewindTo(size_t index) { - done_ = false; - not_done_count_ = params_->search.batch_size; - memset(eos_seen_.data(), 0, eos_seen_.size_bytes()); + ResetDone(); + // Set next tokens to the last tokens in the sequence if (index > 0) { for (int i = 0; i < params_->BatchBeamSize(); i++) { diff --git a/src/search.h b/src/search.h index 9d456368c1..b5d19de908 100644 --- a/src/search.h +++ b/src/search.h @@ -44,6 +44,7 @@ struct Search_Cpu : Search { DeviceSpan GetSequenceLengths() override { return sequence_lengths_; } bool IsDone() const override { return done_; } + void ResetDone(); DeviceSpan GetLogits() const override; void SetLogits(DeviceSpan logits) override; @@ -75,6 +76,7 @@ struct GreedySearch_Cpu : Search_Cpu { void SampleTopKTopP(int /*k*/, float /*p*/, float /*temperature*/) override; // Used by continuous decoding search. + void ResetDone(); void AppendTokens(DeviceSpan& next_tokens) override; void RewindTo(size_t index) override; diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 08af900b56..330ff1b43c 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -338,11 +338,8 @@ TEST(CAPITests, EndToEndPhiBatch) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Decode The Batch @@ -519,11 +516,8 @@ TEST(CAPITests, EndToEndPhi) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Decode The Batch @@ -561,11 +555,12 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + ASSERT_EQ(static_cast(params->GetSearchNumber("max_length")), 40); + ASSERT_EQ(params->GetSearchBool("early_stopping"), true); + ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Decode The Batch @@ -582,6 +577,7 @@ TEST(CAPITests, EndToEndPhiEOSPAD) { const auto* sequence_data = generator->GetSequenceData(0); ASSERT_LE(sequence_length, 40); + ASSERT_EQ(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); const auto* expected_output_start = &expected_output[0]; EXPECT_TRUE(0 == std::memcmp(expected_output_start, sequence_data, sequence_length * sizeof(int32_t))); @@ -658,11 +654,8 @@ TEST(CAPITests, LoadModelFromMemory) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequence); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Decode The Batch @@ -738,11 +731,8 @@ TEST(CAPITests, GreedySearchGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -903,11 +893,8 @@ TEST(CAPITests, SetTerminate) { auto GenerateOutput = [](OgaGenerator* generator, std::unique_ptr tokenizer_stream) { EXPECT_THROW({ - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } }, std::runtime_error); }; @@ -968,11 +955,8 @@ struct Phi2Test { auto generator = OgaGenerator::Create(*model_, *params_); generator->AppendTokenSequences(*input_sequences_); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Decode One at a time @@ -1130,11 +1114,8 @@ TEST(CAPITests, AdaptersTest) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto logits = generator->GetOutput("logits"); @@ -1156,11 +1137,8 @@ TEST(CAPITests, AdaptersTest) { generator->SetActiveAdapter(*adapters, "adapters_a_and_b"); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto logits = generator->GetOutput("logits"); @@ -1210,11 +1188,8 @@ TEST(CAPITests, AdaptersTestMultipleAdapters) { generator->SetActiveAdapter(*adapters, "adapter_b"); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } } @@ -1256,11 +1231,8 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -1278,11 +1250,8 @@ TEST(CAPITests, BatchedRewindGptFp32CAPI) { generator->RewindTo(0); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -1316,11 +1285,8 @@ TEST(CAPITests, RewindGptFp32CAPI) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids.data(), input_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -1334,11 +1300,8 @@ TEST(CAPITests, RewindGptFp32CAPI) { // Rewind to length 5 and verify same output generator->RewindTo(5); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -1353,11 +1316,8 @@ TEST(CAPITests, RewindGptFp32CAPI) { std::vector next_ids{731, 731}; generator->AppendTokens(next_ids.data(), next_ids.size()); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -1386,11 +1346,8 @@ TEST(CAPITests, SetGuidance) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*input_sequences); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto out_string = tokenizer->Decode(generator->GetSequenceData(0), generator->GetSequenceCount(0)); auto output = std::string(out_string).substr(std::string(input_string).size()); diff --git a/test/csharp/TestOnnxRuntimeGenAIAPI.cs b/test/csharp/TestOnnxRuntimeGenAIAPI.cs index 3b1c72d395..4e6a05b5cd 100644 --- a/test/csharp/TestOnnxRuntimeGenAIAPI.cs +++ b/test/csharp/TestOnnxRuntimeGenAIAPI.cs @@ -158,16 +158,16 @@ public void TestGreedySearch() using (var generator = new Generator(model, generatorParams)) { Assert.NotNull(generator); - generator.AppendTokens(inputIDs); + Assert.False(generator.IsDone()); - while (true) + Assert.Equal(generatorParams.GetSearchNumber("max_length"), maxLength); + Assert.Equal(generatorParams.GetSearchBool("early_stopping"), true); + Assert.Equal((int)generator.TokenCount(), generator.GetSequence(0).Length); + + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -175,6 +175,7 @@ public void TestGreedySearch() var sequence = generator.GetSequence(i).ToArray(); var expectedSequence = expectedOutput.Skip((int)i * (int)maxLength).Take((int)maxLength); Assert.Equal(expectedSequence, sequence); + Assert.Equal((int)generator.TokenCount(), generator.GetSequence(i).Length); } } } @@ -216,13 +217,9 @@ public void TestLoadModelFromMemory() generator.AppendTokens(inputIDs); Assert.False(generator.IsDone()); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -296,13 +293,9 @@ public void TestTopKSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -359,13 +352,9 @@ public void TestTopPSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -424,13 +413,9 @@ public void TestTopKTopPSearch() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -635,13 +620,9 @@ public void TestPhi2() generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } for (ulong i = 0; i < batchSize; i++) @@ -745,13 +726,9 @@ public void TestAdapters() using var generator = new Generator(model, genParams); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } using var logits = generator.GetOutput("logits"); @@ -777,13 +754,9 @@ public void TestAdapters() using var generator = new Generator(model, genParams); generator.SetActiveAdapter(adapters, "adapters_a_and_b"); generator.AppendTokenSequences(sequences); - while (true) + while (!generator.IsDone()) { generator.GenerateNextToken(); - if (generator.IsDone()) - { - break; - } } using var logits = generator.GetOutput("logits"); if (_useCudaModel) diff --git a/test/model_tests.cpp b/test/model_tests.cpp index 70dd8c2ea4..25bb1b8b96 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -84,11 +84,8 @@ TEST(ModelTests, GreedySearchGptFp32) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -128,11 +125,8 @@ TEST(ModelTests, BeamSearchGptFp32) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -165,11 +159,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -190,11 +181,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -205,11 +193,8 @@ void Test_GreedySearch_Gpt_Cuda(const char* model_path, const char* model_label) generator->RewindTo(3); std::vector next_ids{731, 731}; generator->AppendTokens(next_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -251,11 +236,8 @@ void Test_BeamSearch_Gpt_Cuda(const char* model_path, const char* model_label) { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -306,11 +288,8 @@ void Test_GreedySearch_Phi3_NvTensorRtRtx(const char* model_path, const char* mo auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } // Verify outputs match expected outputs @@ -362,11 +341,8 @@ void Test_OutOfPlaceKvCache_Phi3_NvTensorRtRtx(const char* model_path, const cha auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(input_ids); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto sequence = generator->GetSequence(0); @@ -404,11 +380,8 @@ Print all primes between 1 and n auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(tokens->Get(0)); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto result = generator->GetSequence(0); @@ -438,11 +411,8 @@ Print all primes between 1 and n auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokens(tokens->Get(0)); - while (true) { + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } auto result = generator->GetSequence(0); diff --git a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm index 9ff3bee78c..25ed5e8172 100644 --- a/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/ios_package_testUITests/ios_package_uitest_cpp_api.mm @@ -54,16 +54,18 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); - while (true) { + XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100); + XCTAssertEqual(params->GetSearchBool("early_stopping"), true); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } const auto output_sequence_length = generator->GetSequenceCount(0); const auto* output_sequence_data = generator->GetSequenceData(0); auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); } @end diff --git a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm index a4e5fdfc63..538289c322 100644 --- a/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm +++ b/test/platform/apple/apple_package_test/macos_package_testUITests/macos_package_uitest_cpp_api.mm @@ -54,16 +54,18 @@ - (void)testCppAPI_Basic { auto generator = OgaGenerator::Create(*model, *params); generator->AppendTokenSequences(*sequences); - while (true) { + XCTAssertEqual(static_cast(params->GetSearchNumber("max_length")), 100); + XCTAssertEqual(params->GetSearchBool("early_stopping"), true); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); + + while (!generator->IsDone()) { generator->GenerateNextToken(); - if (generator->IsDone()) { - break; - } } const auto output_sequence_length = generator->GetSequenceCount(0); const auto* output_sequence_data = generator->GetSequenceData(0); auto out_string = tokenizer->Decode(output_sequence_data, output_sequence_length); + XCTAssertEqual(static_cast(generator->TokenCount()), static_cast(generator->GetSequenceCount(0))); } @end diff --git a/test/python/requirements.txt b/test/python/requirements.txt index 3a851be2e5..2d7c639bdf 100644 --- a/test/python/requirements.txt +++ b/test/python/requirements.txt @@ -7,5 +7,5 @@ sympy pytest onnx onnx_ir>=0.1.3 -transformers +transformers<5.0.0 huggingface_hub[cli] diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index cd6dd43c2d..837bd4726a 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -148,15 +148,18 @@ def test_greedy_search(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731]], dtype=np.int32)) - while True: + + assert int(search_params.get_search_options()["max_length"]) == 10 + assert search_params.get_search_options()["early_stopping"] == True + assert int(generator.token_count()) == 4 + + while not generator.is_done(): # Test getting/setting logits logits = generator.get_logits() generator.set_logits(logits) generator.set_logits(logits) # twice just to be sure buffer is still valid generator.generate_next_token() - if generator.is_done(): - break expected_sequence = np.array( [ @@ -167,7 +170,7 @@ def test_greedy_search(test_data_path, relative_model_path): ) for i in range(batch_size): assert np.array_equal(expected_sequence[i], generator.get_sequence(i)) - + assert int(generator.token_count()) == len(generator.get_sequence(0)) @pytest.mark.parametrize( "relative_model_path", @@ -193,20 +196,16 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break assert generator.get_sequence(0) is not None generator.rewind_to(3) generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break assert generator.get_sequence(0) is not None @@ -217,10 +216,8 @@ def test_rewind_cuda(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 0, 52], [0, 0, 195, 731], [64, 65, 66, 67]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(batch_size): assert generator.get_sequence(i) is not None @@ -233,10 +230,8 @@ def test_rewind_cuda(test_data_path, relative_model_path): dtype=np.int32, ) ) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(batch_size): assert generator.get_sequence(i) is not None @@ -262,20 +257,16 @@ def test_rewind(test_data_path, relative_model_path): generator = og.Generator(model, search_params) generator.append_tokens(np.array([[0, 0, 195, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break assert np.array_equal(expected_sequence, generator.get_sequence(0)) generator.rewind_to(3) generator.append_tokens(np.array([[731, 731]], dtype=np.int32)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break assert np.array_equal(expected_sequence, generator.get_sequence(0)) @@ -403,10 +394,8 @@ def test_batching(device, phi2_for): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -434,10 +423,8 @@ def test_e2e(device, phi2_for): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -469,10 +456,8 @@ def test_load_model_from_memory(device, wrapper_bytes_function, phi2_for): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(len(prompts)): print(tokenizer.decode(generator.get_sequence(0))) @@ -648,10 +633,8 @@ def _split(onnx_model_path: os.PathLike, output_dir: os.PathLike): generator = og.Generator(model, params) generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break expected_output = [ "This is a test.\n # TOD import * doct proofingrad", @@ -848,10 +831,8 @@ def _export_adapter(adapter, adapter_file_name): generator.set_active_adapter(adapters, f"adapter_{i}") generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break @pytest.mark.parametrize("device", devices) @@ -935,10 +916,8 @@ def _prepare_model(test_data_path): else: generator.append_tokens(tokenizer.encode_batch(prompts)) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break @pytest.mark.parametrize("relative_model_path", [Path("audio-preprocessing")]) diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index 643882ae5c..f28d674e21 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -31,10 +31,8 @@ def run_model(model_path: str | bytes | os.PathLike): generator = og.Generator(model, params) generator.append_tokens(sequences) - while True: + while not generator.is_done(): generator.generate_next_token() - if generator.is_done(): - break for i in range(3): assert generator.get_sequence(i) is not None