Merged
40 commits
7df4aa7
Add IsMaxLength API
kunal-vaishnavi Dec 18, 2025
9d41ed9
Add HitEOS and HitMaxLength APIs
kunal-vaishnavi Dec 18, 2025
b12e260
Add language bindings and update unit tests
kunal-vaishnavi Dec 18, 2025
795719d
Update README
kunal-vaishnavi Dec 18, 2025
4a13b80
Fix Java build and initialize pointers
kunal-vaishnavi Dec 18, 2025
3165142
Add checks for beam search
kunal-vaishnavi Dec 20, 2025
a3299aa
Merge branch 'main' into kvaishnavi/fix-early-return
kunal-vaishnavi Jan 21, 2026
74c4fac
Remove HitEOS from examples since GetNextTokens works in streaming mode
kunal-vaishnavi Jan 21, 2026
5cf30ae
Add C++ API for GetNextTokens and use in examples
kunal-vaishnavi Jan 21, 2026
f96fa85
Introduce TokenCount API instead
kunal-vaishnavi Jan 21, 2026
06f6703
Add GetSearchNumber and GetSearchBool APIs
kunal-vaishnavi Jan 22, 2026
d18fe38
Undo accidental change
kunal-vaishnavi Jan 22, 2026
93b6279
Fix return types in Java bindings
kunal-vaishnavi Jan 22, 2026
87ef255
Add missing value call
kunal-vaishnavi Jan 22, 2026
d5c51a1
Update return type for Objective-C binding of TokenCount
kunal-vaishnavi Jan 22, 2026
deac1b1
Add missing return in Java binding of TokenCount
kunal-vaishnavi Jan 22, 2026
606b494
Fix names of APIs called in C API
kunal-vaishnavi Jan 22, 2026
279eaa2
Add missing const references
kunal-vaishnavi Jan 22, 2026
67487c2
Add changes suggested by C++ linter
kunal-vaishnavi Jan 22, 2026
09fc909
Add some more missing const references
kunal-vaishnavi Jan 22, 2026
8d9fda8
Change how Python binding is done
kunal-vaishnavi Jan 22, 2026
a95ac69
Use fullname for pybind dict
kunal-vaishnavi Jan 22, 2026
d70baa1
Define TokenCount binding with PyGenerator instead of OgaGenerator
kunal-vaishnavi Jan 22, 2026
d160ea0
Add changes suggested by C++ linter
kunal-vaishnavi Jan 22, 2026
36d6e3d
Move ApplyChatTemplate into one line and ignore local C++ linter
kunal-vaishnavi Jan 22, 2026
ebe9eaf
Add assertions in unit tests for new APIs
kunal-vaishnavi Jan 22, 2026
e4bc169
Fix language binding API names
kunal-vaishnavi Jan 22, 2026
5ca18c7
Update how chunk size is obtained
kunal-vaishnavi Jan 22, 2026
6f6c0b0
Fix max length in assertion
kunal-vaishnavi Jan 22, 2026
1c2ed9d
Remove default values from Java bindings
kunal-vaishnavi Jan 22, 2026
fb1f8d5
Fix C API call inside Java API
kunal-vaishnavi Jan 22, 2026
2197b1a
Fix value in assert
kunal-vaishnavi Jan 22, 2026
735b73a
Merge branch 'main' into kvaishnavi/fix-early-return
kunal-vaishnavi Jan 22, 2026
7afeb51
Remove breaking changes documentation from README
kunal-vaishnavi Jan 22, 2026
3b990df
Construct vector in return statement
kunal-vaishnavi Jan 24, 2026
27d42b5
Make changes based on PR feedback
kunal-vaishnavi Jan 26, 2026
4461a6e
Add back missing definition
kunal-vaishnavi Jan 26, 2026
6913f9f
Pin transformers to be before v5
kunal-vaishnavi Jan 26, 2026
5a668f3
Cast token count from size_t to int in C#
kunal-vaishnavi Jan 27, 2026
c2a45e9
Simplify getting chunk size value
kunal-vaishnavi Jan 27, 2026
37 changes: 8 additions & 29 deletions README.md
@@ -16,7 +16,7 @@ See documentation at the [ONNX Runtime website](https://onnxruntime.ai/docs/gena

 | Support matrix | Supported now | Under development | On the roadmap|
 | -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | ChatGLM</br>DeepSeek</br>Ernie</br>Fara</br>Gemma</br>GPTOSS</br>Granite</br>Llama</br>Mistral</br>Nemotron</br>OLMo</br>Phi</br>Phi3V</br>Phi4MM</br>Qwen</br>Qwen-2.5VL</br>SmolLM3</br>Whisper</br>| Stable diffusion ||
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Fara <br/> Gemma <br/> gpt-oss <br/> Granite <br/> Llama <br/> Mistral <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen (language + vision) <br/> SmolLM3 <br/> Whisper | Stable diffusion | Multi-modal models |
 | API| Python <br/>C# <br/>C/C++ <br/> Java ^ | Objective-C ||
 | O/S | Linux <br/> Windows <br/>Mac <br/>Android || iOS |||
 | Architecture | x86 <br/> x64 <br/> arm64 ||||
@@ -78,10 +78,8 @@ See [installation instructions](https://onnxruntime.ai/docs/genai/howto/install)

 try:
     generator.append_tokens(input_tokens)
-    while True:
+    while not generator.is_done():
         generator.generate_next_token()
-        if generator.is_done():
-            break
         new_token = generator.get_next_tokens()[0]
         print(tokenizer_stream.decode(new_token), end='', flush=True)
 except KeyboardInterrupt:
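The loop shape in this hunk can be tried without a model. The following is a minimal sketch, assuming a hypothetical `FakeGenerator` that mirrors only the three method names used above (`is_done`, `generate_next_token`, `get_next_tokens`). The class, its scripted token list, and its stop rule are illustrative assumptions, not the onnxruntime-genai implementation.

```python
class FakeGenerator:
    """Hypothetical stand-in exposing only the three methods the loop calls."""

    def __init__(self, scripted_tokens):
        self._scripted = list(scripted_tokens)
        self._position = 0  # number of tokens generated so far

    def is_done(self):
        # Assumed stop rule: done once every scripted token has been produced.
        return self._position >= len(self._scripted)

    def generate_next_token(self):
        if self.is_done():
            raise RuntimeError("generation already finished")
        self._position += 1

    def get_next_tokens(self):
        if self._position == 0:
            raise RuntimeError("no token has been generated yet")
        # One entry per sequence in the batch; this sketch has batch size 1.
        return [self._scripted[self._position - 1]]


def stream_tokens(generator):
    """Same control flow as the updated README loop, collecting instead of printing."""
    streamed = []
    while not generator.is_done():
        generator.generate_next_token()
        streamed.append(generator.get_next_tokens()[0])
    return streamed


print(stream_tokens(FakeGenerator([101, 7, 42])))
```

Because `is_done()` is checked before each step, a generator that is already finished produces no output and `get_next_tokens()` is never called on it.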
@@ -115,13 +113,13 @@ Windows:
 pip list | findstr "onnxruntime-genai"
 ```

-Checkout the version of the examples that correspond to that release.
+Then, check out the version of the examples that corresponds to that release.

 ```bash
 # Clone the repo
 git clone https://github.com/microsoft/onnxruntime-genai.git && cd onnxruntime-genai
 # Checkout the branch for the version you are using
-git checkout v0.11.4
+git checkout v0.11.5
 cd examples
 ```

@@ -145,30 +143,11 @@ Navigate to the examples folder in the main branch.
 cd examples
 ```

-## Breaking API changes
+To install the nightly Python build:

-### v0.11.0
-
-Between `v0.11.0` and `v0.10.1`, there is a breaking API usage change to improve model quality during multi-turn conversations.
-
-Previously, the decoding loop could be written as follows.
-
-```
-while not IsDone():
-    GenerateToken()
-    GetLastToken()
-    PrintLastToken()
-```
-
-In 0.11.0, the decoding loop should now be written as follows.
-
-```
-while True:
-    GenerateToken()
-    if IsDone():
-        break
-    GetLastToken()
-    PrintLastToken()
+```bash
+# Change onnxruntime-genai to the Python package you want to install
+pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-genai
 ```

 ## Roadmap
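The two loop shapes in the removed README section differ only in when the done check runs relative to reading the token. The following is a minimal control-flow sketch, assuming a hypothetical `CountdownGenerator` whose `is_done()` becomes true immediately after its final step. The class and method names are made up for illustration and say nothing about how the library's `IsDone` behaved in either release.

```python
class CountdownGenerator:
    """Hypothetical generator: is_done() turns True right after the last step."""

    def __init__(self, steps):
        self._remaining = steps
        self._last = None

    def is_done(self):
        return self._remaining == 0

    def generate_token(self):
        self._last = self._remaining  # placeholder token id
        self._remaining -= 1

    def get_last_token(self):
        return self._last


def check_first_loop(gen):
    """Shape 1: test is_done() before each step, then read."""
    out = []
    while not gen.is_done():
        gen.generate_token()
        out.append(gen.get_last_token())
    return out


def check_after_loop(gen):
    """Shape 2: step first, then break before reading if done."""
    out = []
    while True:
        gen.generate_token()
        if gen.is_done():
            break
        out.append(gen.get_last_token())
    return out


print(check_first_loop(CountdownGenerator(3)))  # [3, 2, 1]
print(check_after_loop(CountdownGenerator(3)))  # [3, 2]
```

Under that assumption the second shape never reads the token produced by the final step, so the two shapes are not interchangeable for a given generator.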
9 changes: 2 additions & 7 deletions examples/c/src/model_chat.cpp
@@ -91,20 +91,15 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   const auto current_token_count = generator->GetSequenceCount(0);

   try {
-    while (true) {
+    while (!generator->IsDone()) {
       generator->GenerateNextToken();

       if (is_first_token) {
         timing.RecordFirstTokenTimestamp();
         is_first_token = false;
       }

-      if (generator->IsDone()) {
-        break;
-      }
-
-      const auto num_tokens = generator->GetSequenceCount(0);
-      const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+      const auto new_token = generator->GetNextTokens()[0];
       std::cout << tokenizer_stream->Decode(new_token) << std::flush;
     }
   } catch (const std::exception& e) {
9 changes: 2 additions & 7 deletions examples/c/src/model_qa.cpp
@@ -82,20 +82,15 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   generator->AppendTokenSequences(*sequences);

   try {
-    while (true) {
+    while (!generator->IsDone()) {
       generator->GenerateNextToken();

       if (is_first_token) {
         timing.RecordFirstTokenTimestamp();
         is_first_token = false;
       }

-      if (generator->IsDone()) {
-        break;
-      }
-
-      const auto num_tokens = generator->GetSequenceCount(0);
-      const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+      const auto new_token = generator->GetNextTokens()[0];
       std::cout << tokenizer_stream->Decode(new_token) << std::flush;
     }
   } catch (const std::exception& e) {
10 changes: 2 additions & 8 deletions examples/c/src/model_vision.cpp
@@ -90,15 +90,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   auto generator = OgaGenerator::Create(*model, *params);
   generator->SetInputs(*input_tensors);

-  while (true) {
+  while (!generator->IsDone()) {
     generator->GenerateNextToken();
-
-    if (generator->IsDone()) {
-      break;
-    }
-
-    const auto num_tokens = generator->GetSequenceCount(0);
-    const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+    const auto new_token = generator->GetNextTokens()[0];
     std::cout << stream->Decode(new_token) << std::flush;
   }

10 changes: 2 additions & 8 deletions examples/c/src/phi4-mm.cpp
@@ -101,15 +101,9 @@ void CXX_API(const char* model_path, const char* execution_provider) {
   auto generator = OgaGenerator::Create(*model, *params);
   generator->SetInputs(*input_tensors);

-  while (true) {
+  while (!generator->IsDone()) {
     generator->GenerateNextToken();
-
-    if (generator->IsDone()) {
-      break;
-    }
-
-    const auto num_tokens = generator->GetSequenceCount(0);
-    const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
+    const auto new_token = generator->GetNextTokens()[0];
     std::cout << stream->Decode(new_token) << std::flush;
   }

5 changes: 1 addition & 4 deletions examples/c/src/whisper.cpp
@@ -57,11 +57,8 @@ void CXX_API(const char* model_path, int32_t num_beams) {
   auto generator = OgaGenerator::Create(*model, *params);
   generator->SetInputs(*inputs);

-  while (true) {
+  while (!generator->IsDone()) {
     generator->GenerateNextToken();
-    if (generator->IsDone()) {
-      break;
-    }
   }

   for (size_t i = 0; i < static_cast<size_t>(num_beams * batch_size); ++i) {
18 changes: 3 additions & 15 deletions examples/csharp/HelloPhi/Program.cs
@@ -127,13 +127,9 @@ static string GetPrompt(bool interactive)
     using var generator = new Generator(model, generatorParams);
     generator.AppendTokenSequences(sequences);
     var watch = System.Diagnostics.Stopwatch.StartNew();
-    while (true)
+    while (!generator.IsDone())
     {
         generator.GenerateNextToken();
-        if (generator.IsDone())
-        {
-            break;
-        }
     }

     var outputSequence = generator.GetSequence(0);
@@ -155,13 +151,9 @@ static string GetPrompt(bool interactive)
     using var generator = new Generator(model, generatorParams);
     generator.AppendTokenSequences(sequences);
     var watch = System.Diagnostics.Stopwatch.StartNew();
-    while (true)
+    while (!generator.IsDone())
     {
         generator.GenerateNextToken();
-        if (generator.IsDone())
-        {
-            break;
-        }
         Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0]));
     }
     Console.WriteLine();
@@ -196,13 +188,9 @@ static string GetPrompt(bool interactive)
     var sequences = tokenizer.Encode(tokenizer.ApplyChatTemplate("", messages, "", true));
     var watch = System.Diagnostics.Stopwatch.StartNew();
     generator.AppendTokenSequences(sequences);
-    while (true)
+    while (!generator.IsDone())
     {
         generator.GenerateNextToken();
-        if (generator.IsDone())
-        {
-            break;
-        }
         Console.Write(tokenizerStream.Decode(generator.GetNextTokens()[0]));
     }
     Console.WriteLine();
6 changes: 1 addition & 5 deletions examples/csharp/HelloPhi3V/Program.cs
@@ -168,13 +168,9 @@ void PrintUsage()
     using var generator = new Generator(model, generatorParams);
     generator.SetInputs(inputTensors);
     var watch = System.Diagnostics.Stopwatch.StartNew();
-    while (true)
+    while (!generator.IsDone())
     {
         generator.GenerateNextToken();
-        if (generator.IsDone())
-        {
-            break;
-        }
         Console.Write(stream.Decode(generator.GetNextTokens()[0]));
     }
     watch.Stop();
6 changes: 1 addition & 5 deletions examples/csharp/HelloPhi4MM/Program.cs
@@ -215,13 +215,9 @@ void PrintUsage()
     using var generator = new Generator(model, generatorParams);
     generator.SetInputs(inputTensors);
     var watch = System.Diagnostics.Stopwatch.StartNew();
-    while (true)
+    while (!generator.IsDone())
     {
         generator.GenerateNextToken();
-        if (generator.IsDone())
-        {
-            break;
-        }
         Console.Write(stream.Decode(generator.GetNextTokens()[0]));
     }
     watch.Stop();
5 changes: 1 addition & 4 deletions examples/python/awq-quantized-model.py
@@ -108,11 +108,8 @@ def run_model(args):
     print("Output: ", end="", flush=True)

     try:
-        while True:
+        while not generator.is_done():
             generator.generate_next_token()
-            if generator.is_done():
-                break
-
             new_token = generator.get_next_tokens()[0]
             print(tokenizer_stream.decode(new_token), end="", flush=True)
     except KeyboardInterrupt:
5 changes: 1 addition & 4 deletions examples/python/model-chat.py
@@ -216,16 +216,13 @@ def main(args):
     print("Output: ", end="", flush=True)

     try:
-        while True:
+        while not generator.is_done():
             generator.generate_next_token()
             if args.timings:
                 if first:
                     first_token_timestamp = time.time()
                     first = False

-            if generator.is_done():
-                break
-
             new_token = generator.get_next_tokens()[0]
             print(tokenizer_stream.decode(new_token), end="", flush=True)
             if args.timings:
4 changes: 1 addition & 3 deletions examples/python/model-generate.py
@@ -99,10 +99,8 @@ def main(args):
     if args.verbose:
         print("Generating tokens ...\n")
     start_time = time.time()
-    while True:
+    while not generator.is_done():
         generator.generate_next_token()
-        if generator.is_done():
-            break
     run_time = time.time() - start_time

     for i in range(len(prompts)):
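When nothing is streamed, as in this example and in `whisper.cpp`, the loop body reduces to the generate call and the full sequences are read once the loop exits. The following is a minimal sketch of that pattern, assuming a hypothetical `FakeBatchGenerator`; its "model" (previous id plus one) and its fixed step budget are placeholders, not library behavior.

```python
class FakeBatchGenerator:
    """Hypothetical batch generator; each sequence grows by one token per step."""

    def __init__(self, prompts, max_new_tokens):
        self._sequences = [list(p) for p in prompts]  # prompts must be non-empty
        self._steps_left = max_new_tokens

    def is_done(self):
        return self._steps_left == 0

    def generate_next_token(self):
        for seq in self._sequences:
            seq.append(seq[-1] + 1)  # placeholder "model": previous id plus one
        self._steps_left -= 1

    def get_sequence(self, index):
        return list(self._sequences[index])


prompts = [[10, 11], [20]]
generator = FakeBatchGenerator(prompts, max_new_tokens=3)

# Run to completion without reading anything inside the loop.
while not generator.is_done():
    generator.generate_next_token()

# Read every finished sequence afterwards, one per prompt.
outputs = [generator.get_sequence(i) for i in range(len(prompts))]
print(outputs)  # [[10, 11, 12, 13, 14], [20, 21, 22, 23]]
```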
5 changes: 1 addition & 4 deletions examples/python/model-qa.py
@@ -207,16 +207,13 @@ def main(args):
     print("Output: ", end="", flush=True)

     try:
-        while True:
+        while not generator.is_done():
             generator.generate_next_token()
             if args.timings:
                 if first:
                     first_token_timestamp = time.time()
                     first = False

-            if generator.is_done():
-                break
-
             new_token = generator.get_next_tokens()[0]
             print(tokenizer_stream.decode(new_token), end="", flush=True)
             if args.timings:
5 changes: 1 addition & 4 deletions examples/python/model-vision.py
@@ -150,11 +150,8 @@ def run(args: argparse.Namespace):
     generator.set_inputs(inputs)
     start_time = time.time()

-    while True:
+    while not generator.is_done():
         generator.generate_next_token()
-        if generator.is_done():
-            break
-
         new_token = generator.get_next_tokens()[0]
         print(stream.decode(new_token), end="", flush=True)

5 changes: 1 addition & 4 deletions examples/python/phi3-qa.py
@@ -75,16 +75,13 @@ def main(args):
     print("Output: ", end="", flush=True)

     try:
-        while True:
+        while not generator.is_done():
             generator.generate_next_token()
             if args.timings:
                 if first:
                     first_token_timestamp = time.time()
                     first = False

-            if generator.is_done():
-                break
-
             new_token = generator.get_next_tokens()[0]
             print(tokenizer_stream.decode(new_token), end="", flush=True)
             if args.timings:
5 changes: 1 addition & 4 deletions examples/python/phi4-mm.py
@@ -145,11 +145,8 @@ def run(args: argparse.Namespace):
     generator.set_inputs(inputs)
     start_time = time.time()

-    while True:
+    while not generator.is_done():
         generator.generate_next_token()
-        if generator.is_done():
-            break
-
         new_token = generator.get_next_tokens()[0]
         print(tokenizer_stream.decode(new_token), end="", flush=True)

2 changes: 1 addition & 1 deletion src/config.h
@@ -301,7 +301,7 @@ struct Config {
     int top_k{50}; // Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
     float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
     float temperature{1.0f}; // Temperature to control during generation. Default is 1.0.
-    bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
+    bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
     int no_repeat_ngram_size{}; // Unused param
     float diversity_penalty{}; // Unused param
     float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
9 changes: 9 additions & 0 deletions src/csharp/Generator.cs
@@ -46,6 +46,15 @@ public void AppendTokenSequences(Sequences sequences)
             Result.VerifySuccess(NativeMethods.OgaGenerator_AppendTokenSequences(_generatorHandle, sequences.Handle));
         }

+        /// <summary>
+        /// Gets the number of tokens in the generator
+        /// </summary>
+        /// <returns>The token count</returns>
+        public ulong TokenCount()
+        {
+            return NativeMethods.OgaGenerator_TokenCount(_generatorHandle).ToUInt64();
+        }
+
         public void GenerateNextToken()
         {
             Result.VerifySuccess(NativeMethods.OgaGenerator_GenerateNextToken(_generatorHandle));
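One possible use of a token count is to track how much of the length budget remains. The following is a minimal sketch, assuming a hypothetical `FakeGenerator` whose `token_count()` returns prompt tokens plus generated tokens. That meaning, the method name, and the `max_length` attribute are assumptions made for illustration; the diff only documents the value as "the number of tokens in the generator".

```python
class FakeGenerator:
    """Hypothetical generator that tracks how many tokens it holds."""

    def __init__(self, max_length):
        self.max_length = max_length
        self._tokens = []

    def append_tokens(self, tokens):
        self._tokens.extend(tokens)

    def generate_next_token(self):
        self._tokens.append(0)  # placeholder token id

    def is_done(self):
        # Assumed stop rule: stop once the length cap is reached.
        return len(self._tokens) >= self.max_length

    def token_count(self):
        # Assumed meaning: prompt tokens plus generated tokens.
        return len(self._tokens)


def remaining_budget(generator):
    """Tokens that can still be produced before the cap, never negative."""
    return max(generator.max_length - generator.token_count(), 0)


generator = FakeGenerator(max_length=8)
generator.append_tokens([5, 6, 7])
before = remaining_budget(generator)

while not generator.is_done():
    generator.generate_next_token()
after = remaining_budget(generator)

print(before, after)  # 5 0
```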
15 changes: 11 additions & 4 deletions src/csharp/GeneratorParams.cs
@@ -29,14 +29,21 @@ public void SetSearchOption(string searchOption, bool value)
             Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), value));
         }

-        public void TryGraphCaptureWithMaxBatchSize(int maxBatchSize)
+        public void SetGuidance(string type, string data, bool enableFFTokens = false)
         {
-            Console.WriteLine("TryGraphCaptureWithMaxBatchSize is deprecated and will be removed in a future release.");
+            Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens));
         }

-        public void SetGuidance(string type, string data, bool enableFFTokens = false)
+        public double GetSearchNumber(string searchOption)
         {
-            Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetGuidance(_generatorParamsHandle, StringUtils.ToUtf8(type), StringUtils.ToUtf8(data), enableFFTokens));
+            Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchNumber(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out double value));
+            return value;
+        }
+
+        public bool GetSearchBool(string searchOption)
+        {
+            Result.VerifySuccess(NativeMethods.OgaGeneratorParamsGetSearchBool(_generatorParamsHandle, StringUtils.ToUtf8(searchOption), out bool value));
+            return value;
         }

         ~GeneratorParams()
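The getters added here pair with `SetSearchOption`: numeric options come back as `double` and boolean options as `bool`. The following is a minimal sketch of that split, assuming a hypothetical `FakeGeneratorParams` with made-up method names. The real bindings call into the native library, and how they report an unknown option name is not shown in this diff, so the `KeyError` below is an assumption.

```python
class FakeGeneratorParams:
    """Hypothetical params object storing numeric and boolean options separately."""

    def __init__(self):
        self._numbers = {}
        self._bools = {}

    def set_search_option(self, name, value):
        # bool must be tested first: in Python, bool is a subclass of int.
        if isinstance(value, bool):
            self._bools[name] = value
        elif isinstance(value, (int, float)):
            self._numbers[name] = float(value)  # numbers are kept as doubles
        else:
            raise TypeError(f"unsupported value type for {name!r}")

    def get_search_number(self, name):
        if name not in self._numbers:
            raise KeyError(f"no numeric search option named {name!r}")
        return self._numbers[name]

    def get_search_bool(self, name):
        if name not in self._bools:
            raise KeyError(f"no boolean search option named {name!r}")
        return self._bools[name]


params = FakeGeneratorParams()
params.set_search_option("max_length", 256)
params.set_search_option("early_stopping", True)

# The numeric getter returns a float, so integer-valued options need a cast.
max_length = int(params.get_search_number("max_length"))
early_stopping = params.get_search_bool("early_stopping")
print(max_length, early_stopping)  # 256 True
```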