Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;245f6667babf9668b862ac4513c69ea95117c295
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;9f1f67d6d075793a0828b24e73d50803eb657e9a

# These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d
Expand Down
51 changes: 46 additions & 5 deletions examples/python/model-vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,28 @@
import glob
import json
import os
import readline
import time
from pathlib import Path

import onnxruntime_genai as og

# og.set_log_options(enabled=True, model_input_values=True, model_output_values=True)

# Tool-calling system prompt for Qwen/Fara models
FARA_SYSTEM_PROMPT = """You are a web agent trying to complete user tasks on websites using function calls.

The functions at your disposal are:
<tools>
{"type": "function", "function": {"name": "computer_use", "description": "Use a mouse and keyboard to interact with a computer based on screenshots.\\n- This is an interface to a web browser. You do not have access to a terminal or applications menu, only the browser.\\n- Some pages, etc. may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click a home page icon and a window doesn't change, try wait and taking another screenshot.\\n- Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\\n- If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\\n- Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\\n- When a separate scrollable container prominently overlays the webpage, if you want to scroll within it, you typically need to mouse_move() over it first and then scroll().\\nScreen resolution: 1428x896", "parameters": {"properties": {"action": {"description": "The action to perform. The available actions are:\\n* `key`: Press keyboard keys, like \\"Enter\\", \\"Alt\\", \\"Shift\\", \\"Tab\\", \\"Control\\", \\"Backspace\\", \\"Delete\\", \\"Escape\\", etc. 
Keys are pressed down in the order given, then released in reverse order.\\n* `type`: Type a string of text on the keyboard.\\n* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.\\n* `left_click`: Click the left mouse button.\\n* `scroll`: Performs a scroll of the mouse scroll wheel.\\n* `visit_url`: Visit a specified URL.\\n* `web_search`: Perform a web search with a specified query.\\n* `history_back`: Go back to the previous page in the browser history.\\n* `pause_and_memorize_fact`: Pause and memorize a fact for future reference.\\n* `wait`: Wait specified seconds for the change to happen.\\n* `terminate`: Terminate the current task and report its completion status.", "enum": ["key", "type", "mouse_move", "left_click", "scroll", "visit_url", "web_search", "history_back", "pause_and_memorize_fact", "wait", "terminate"], "type": "string"}, "keys": {"description": "Keyboard keys to be pressed in order. Required only by `action=key`.", "type": "array"}, "text": {"description": "Text to type. Required only by `action=type`.", "type": "string"}, "press_enter": {"description": "Whether to press the 'Enter' key after typing. Required only by `action=type`.", "type": "boolean"}, "delete_existing_text": {"description": "Whether to delete existing text before typing. Required only by `action=type`.", "type": "boolean"}, "coordinate": {"description": "[x, y]: The x (pixels from the left edge) and y (pixels from the top edge) coordinates to move the mouse to. Required only by `action=left_click`, `action=mouse_move`, and `action=type`.", "type": "array"}, "pixels": {"description": "The amount of scrolling to perform. Positive values scroll up, negative values scroll down. Required only by `action=scroll`.", "type": "number"}, "url": {"description": "The URL to visit. Required only by `action=visit_url`.", "type": "string"}, "query": {"description": "The query to search for. 
Required only by `action=web_search`.", "type": "string"}, "fact": {"description": "The fact to remember for the future. Required only by `action=pause_and_memorize_fact`.", "type": "string"}, "time": {"description": "Number of seconds to wait. Required only by `action=wait`.", "type": "number"}, "status": {"description": "The status of the task. Required only by `action=terminate`.", "type": "string", "enum": ["success", "failure"]}}, "required": ["action"], "type": "object"}}}
</tools>

To make a function call, you should output a json object inside <tool_call></tool_call> XML tags. The json object must contain the function name and its arguments, like this:
<tool_call>
{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}
</tool_call>
"""


def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name):
curr_path = Path(current_dir).absolute()
Expand All @@ -26,10 +41,20 @@ def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name):


def _complete(text, state):
return (glob.glob(text + "*") + [None])[state]
return [*glob.glob(text + "*"), None][state]


def run(args: argparse.Namespace):
if args.use_winml:
try:
import winml

print(winml.register_execution_providers(ort=False, ort_genai=True))
except ImportError:
print("WinML not available, using default execution providers")
except Exception as e:
print(f"Failed to register WinML execution providers: {e}")

print("Loading model...")
config = og.Config(args.model_path)
if args.execution_provider != "follow_config":
Expand All @@ -49,8 +74,6 @@ def run(args: argparse.Namespace):
while True:
if interactive:
try:
import readline

readline.set_completer_delims(" \t\n;")
readline.parse_and_bind("tab: complete")
readline.set_completer(_complete)
Expand Down Expand Up @@ -80,7 +103,7 @@ def run(args: argparse.Namespace):
if len(image_paths) == 0:
print("No image provided")
else:
for i, image_path in enumerate(image_paths):
for _, image_path in enumerate(image_paths):
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
print(f"Using image: {image_path}")
Expand All @@ -101,6 +124,10 @@ def run(args: argparse.Namespace):
# Combine all image tags and text into one user message
content = "".join([f"<|image_{i + 1}|>\n" for i in range(len(image_paths))]) + text
messages.append({"role": "user", "content": content})
elif model.type in ["qwen2_5_vl", "fara"]:
messages.append({"role": "system", "content": FARA_SYSTEM_PROMPT})
content = "".join(["<|vision_start|><|image_pad|><|vision_end|>" for _ in image_paths]) + text
messages.append({"role": "user", "content": content})
else:
# Gemma3-style multimodal: structured content
content_list = [{"type": "image"} for _ in image_paths]
Expand All @@ -116,7 +143,8 @@ def run(args: argparse.Namespace):

print("Generating response...")
params = og.GeneratorParams(model)
params.set_search_options(max_length=7680)
max_length = args.max_length if args.max_length else 7680
params.set_search_options(max_length=max_length)

generator = og.Generator(model, params)
generator.set_inputs(inputs)
Expand Down Expand Up @@ -162,11 +190,24 @@ def run(args: argparse.Namespace):
parser.add_argument(
"-pr", "--prompt", required=False, help="Input prompts to generate tokens from, mainly for CI usage"
)
# Optional cap on the generation length. When the flag is omitted the value
# stays None and run() falls back to 7680, so the help text states that
# fallback instead of claiming a config-derived default.
parser.add_argument(
    "--max_length",
    type=int,
    required=False,
    default=None,
    help="Maximum generation length. Defaults to 7680 when not specified.",
)
parser.add_argument(
"--non-interactive",
action=argparse.BooleanOptionalAction,
required=False,
help="Non-interactive mode, mainly for CI usage",
)
parser.add_argument(
"--use-winml",
action="store_true",
required=False,
help="Register WinML execution providers before loading the model",
)
args = parser.parse_args()
run(args)
108 changes: 98 additions & 10 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ struct Decoder_Element : JSON::Element {
v_.sliding_window = Config::Model::Decoder::SlidingWindow{};
return sliding_window_;
}
// Support object-style pipeline: "pipeline": { "embeddings": { ... }, ... }
if (name == "pipeline") {
pipeline_object_ = std::make_unique<PipelineModelObject_Element>(v_.pipeline);
return *pipeline_object_;
}
throw JSON::unknown_value_error{};
}

Expand All @@ -605,6 +610,7 @@ struct Decoder_Element : JSON::Element {
DecoderOutputs_Element outputs_{v_.outputs};
Pipeline_Element pipeline_{v_.pipeline};
SlidingWindow_Element sliding_window_{v_.sliding_window};
std::unique_ptr<PipelineModelObject_Element> pipeline_object_; // object-style pipeline support
};

struct VisionInputs_Element : JSON::Element {
Expand All @@ -615,6 +621,8 @@ struct VisionInputs_Element : JSON::Element {
v_.pixel_values = JSON::Get<std::string_view>(value);
} else if (name == "image_sizes") {
v_.image_sizes = JSON::Get<std::string_view>(value);
} else if (name == "image_grid_thw") {
v_.image_grid_thw = JSON::Get<std::string_view>(value);
} else if (name == "attention_mask") {
v_.attention_mask = JSON::Get<std::string_view>(value);
} else {
Expand All @@ -641,6 +649,77 @@ struct VisionOutputs_Element : JSON::Element {
Config::Model::Vision::Outputs& v_;
};

// Vision pipeline support structures
struct VisionPipelineModel_Element : JSON::Element {
explicit VisionPipelineModel_Element(Config::Model::Vision::PipelineModel& v) : v_{v} {}

void OnValue(std::string_view name, JSON::Value value) override {
if (name == "filename") {
v_.filename = JSON::Get<std::string_view>(value);
} else if (name == "run_on_cpu") {
Comment thread
tianleiwu marked this conversation as resolved.
v_.run_on_cpu = JSON::Get<bool>(value);
} else {
throw JSON::unknown_value_error{};
}
}

Element& OnObject(std::string_view name) override {
if (name == "session_options") {
v_.session_options = Config::SessionOptions{};
session_options_ = std::make_unique<SessionOptions_Element>(*v_.session_options);
return *session_options_;
}
if (name == "run_options") {
v_.run_options = Config::RunOptions{};
run_options_ = std::make_unique<RunOptions_Element>(*v_.run_options);
return *run_options_;
}
throw JSON::unknown_value_error{};
}

Element& OnArray(std::string_view name) override {
if (name == "inputs") {
return inputs_;
}
if (name == "outputs") {
return outputs_;
}
throw JSON::unknown_value_error{};
}

private:
Config::Model::Vision::PipelineModel& v_;
std::unique_ptr<SessionOptions_Element> session_options_;
std::unique_ptr<RunOptions_Element> run_options_;
StringArray_Element inputs_{v_.inputs};
StringArray_Element outputs_{v_.outputs};
};

// JSON handler for an object whose keys each name one vision pipeline stage,
// e.g. { "patch_embed": { ... }, ... }. Every key appends a new PipelineModel
// to the referenced vector, records the key as its model_id, and hands back a
// handler that fills in that stage.
struct VisionPipelineModelObject_Element : JSON::Element {
  explicit VisionPipelineModelObject_Element(std::vector<Config::Model::Vision::PipelineModel>& v) : v_{v} {}

  Element& OnObject(std::string_view name) override {
    Config::Model::Vision::PipelineModel& stage = v_.emplace_back();
    stage.model_id = name;
    // NOTE(review): both vectors can reallocate on a later call; presumably the
    // parser is finished with the returned handler before the next key is
    // visited -- confirm against the JSON parser.
    return elements_.emplace_back(stage);
  }

 private:
  std::vector<Config::Model::Vision::PipelineModel>& v_;
  std::vector<VisionPipelineModel_Element> elements_;  // one handler per stage seen so far
};

// JSON handler for an array-style vision pipeline. Every object encountered
// inside the array is forwarded to the same VisionPipelineModelObject_Element,
// which appends the stages it finds to the shared vector.
struct VisionPipeline_Element : JSON::Element {
  explicit VisionPipeline_Element(std::vector<Config::Model::Vision::PipelineModel>& v) : v_{v} {}

  // The name of the array entry is not used; all entries share one handler.
  Element& OnObject(std::string_view /*name*/) override {
    return object_;
  }

 private:
  std::vector<Config::Model::Vision::PipelineModel>& v_;
  VisionPipelineModelObject_Element object_{v_};
};

struct Vision_Element : JSON::Element {
explicit Vision_Element(Config::Model::Vision& v) : v_{v} {}

Expand Down Expand Up @@ -673,6 +752,18 @@ struct Vision_Element : JSON::Element {
if (name == "outputs") {
return outputs_;
}
// Support object-style pipeline for vision: "pipeline": { "patch_embed": { ... }, ... }
if (name == "pipeline") {
vision_pipeline_object_ = std::make_unique<VisionPipelineModelObject_Element>(v_.pipeline);
return *vision_pipeline_object_;
}
throw JSON::unknown_value_error{};
}

Element& OnArray(std::string_view name) override {
if (name == "pipeline") {
return pipeline_element_;
}
throw JSON::unknown_value_error{};
}

Expand All @@ -682,6 +773,8 @@ struct Vision_Element : JSON::Element {
std::unique_ptr<RunOptions_Element> run_options_;
VisionInputs_Element inputs_{v_.inputs};
VisionOutputs_Element outputs_{v_.outputs};
VisionPipeline_Element pipeline_element_{v_.pipeline};
std::unique_ptr<VisionPipelineModelObject_Element> vision_pipeline_object_; // object-style pipeline support
};

struct SpeechInputs_Element : JSON::Element {
Expand Down Expand Up @@ -1212,19 +1305,14 @@ void ClearDecoderProviderOptionsHardwareVendorId(Config& config, std::string_vie
struct Root_Element : JSON::Element {
explicit Root_Element(Config& config) : config_{config} {}

void OnValue(std::string_view name, JSON::Value value) override {
void OnValue(std::string_view /*name*/, JSON::Value /*value*/) override {
// No top-level scalar values currently supported
}

Element& OnObject(std::string_view name) override {
if (name == "model") {
return model_element_;
}
if (name == "search") {
return search_element_;
}
if (name == "engine") {
return engine_element_;
}
if (name == "model") return model_element_;
if (name == "search") return search_element_;
if (name == "engine") return engine_element_;
throw JSON::unknown_value_error{};
}

Expand Down
13 changes: 13 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,22 @@ struct Config {
std::string config_filename{"processor_config.json"};
std::optional<std::string> adapter_filename{};

// Vision pipeline support (patch embed -> vision attn -> patch merger)
// Describes one stage (one model file) of the vision pipeline.
struct PipelineModel {
// Model file name for this stage.
std::string filename;
// Per-stage session options; left empty unless set for this stage.
std::optional<SessionOptions> session_options;
// Per-stage run options; left empty unless set for this stage.
std::optional<RunOptions> run_options;
std::string model_id; // Identifier used to link outputs to subsequent stages
std::vector<std::string> inputs; // Graph input names
std::vector<std::string> outputs; // Graph output names
bool run_on_cpu{false}; // If true force CPU EP when multiple EPs are configured
};
// Empty by default, i.e. no vision pipeline stages are configured.
std::vector<PipelineModel> pipeline; // Ordered pipeline models

// Input names used by the vision model. Each member starts from a value in
// Defaults.
struct Inputs {
std::string pixel_values{Defaults::PixelValuesName};
std::string image_sizes{Defaults::ImageSizesName};
// Shares its default (Defaults::ImageSizesName) with image_sizes.
std::string image_grid_thw{Defaults::ImageSizesName}; // Qwen2.5-VL uses image_grid_thw, defaults to image_sizes
std::string attention_mask{Defaults::ImageAttentionMaskName}; // image attention mask
} inputs;

Expand Down
24 changes: 17 additions & 7 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,24 @@ DeviceSpan<int32_t> Generator::AllocateInputIdsOnDevice(cpu_span<const int32_t>

auto input_ids_device = state_->params_->p_device->Allocate<int32_t>(padded_input_ids_size);
auto cpu_span = input_ids_device.CpuSpan();
auto padding_begin = cpu_span.begin();
auto data_end = cpu_span.end();
if (model_->config_->model.decoder.sliding_window.has_value() && model_->config_->model.decoder.sliding_window->alignment == "left") {
padding_begin = cpu_span.begin() + input_ids.size();
data_end = padding_begin;

// Handle padding based on alignment setting for sliding window models
if (padded_input_ids_size > input_ids.size()) {
const bool left_align = model_->config_->model.decoder.sliding_window.has_value() &&
model_->config_->model.decoder.sliding_window->alignment == "left";

if (left_align) {
// Left alignment: padding first, then data
std::fill_n(cpu_span.begin(), padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id);
std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin() + (padded_input_ids_size - input_ids.size()));
} else {
// Right alignment (default): data first, then padding
std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin());
std::fill(cpu_span.begin() + input_ids.size(), cpu_span.end(), model_->config_->model.pad_token_id);
}
} else {
std::copy(input_ids.begin(), input_ids.end(), cpu_span.begin());
}
std::fill_n(padding_begin, padded_input_ids_size - input_ids.size(), model_->config_->model.pad_token_id);
std::copy_backward(input_ids.begin(), input_ids.end(), data_end);
input_ids_device.CopyCpuToDevice();
return input_ids_device;
}
Expand Down
30 changes: 25 additions & 5 deletions src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, DeviceSpan<
: State{params, model},
model_{model},
kv_cache_(CreateKeyValueCache(*this)),
position_inputs_{model, *this, sequence_lengths_unk, model_.config_->model.decoder.inputs.attention_mask} {
position_inputs_{CreatePositionInputs(*this, sequence_lengths_unk, model_.config_->model.decoder.inputs.attention_mask)} {
input_ids_.Add();
position_inputs_.Add();
position_inputs_->Add();
logits_.Add();
kv_cache_->Add();
}
Expand Down Expand Up @@ -79,15 +79,35 @@ DeviceSpan<float> DecoderOnly_State::RunWithChunking(int total_length, DeviceSpa
}

// Rewinds the position inputs and the key/value cache to `index`.
// position_inputs_ is used through operator-> here, matching how it is used
// elsewhere in this class; the leftover direct-member call is dropped so the
// rewind happens exactly once.
void DecoderOnly_State::RewindTo(size_t index) {
  position_inputs_->RewindTo(index);
  kv_cache_->RewindTo(index);
}

// Advances the per-step model inputs/outputs after new tokens were produced.
//
// next_tokens   tokens fed to the next step
// beam_indices  forwarded to the key/value cache update
// total_length  total sequence length so far; clamped to the sliding window
//               size for the position inputs and/or the KV cache when the
//               sliding window config asks for it
//
// Each component is updated exactly once, with the (possibly clamped) length.
// The earlier unclamped position/KV-cache updates are removed so they are not
// applied a second time.
void DecoderOnly_State::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> beam_indices, int total_length) {
  input_ids_.Update(next_tokens);
  const size_t new_length = static_cast<size_t>(input_ids_.GetShape()[1]);

  // Effective lengths default to the full sequence length.
  int position_length = total_length;
  int kv_cache_length = total_length;

  const auto& sliding_window = model_.config_->model.decoder.sliding_window;
  if (sliding_window.has_value() && sliding_window->window_size > 0) {
    const int window_size = sliding_window->window_size;

    // Position IDs are clamped when slide_inputs is true.
    if (sliding_window->slide_inputs) {
      position_length = std::min(total_length, window_size);
    }

    // KV cache is clamped when slide_key_value_cache is true.
    if (sliding_window->slide_key_value_cache) {
      kv_cache_length = std::min(total_length, window_size);
    }
  }

  position_inputs_->Update(next_tokens, position_length, static_cast<int>(new_length));
  kv_cache_->Update(beam_indices, kv_cache_length);
  logits_.Update(next_tokens, new_length);
}

Expand Down
Loading
Loading