Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions examples/python/model-qa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import onnxruntime_genai as og
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License

import onnxruntime_genai as og
import argparse
import time
import numpy as np

def main(args):
if args.verbose: print("Loading model...")
Expand Down Expand Up @@ -45,14 +47,16 @@ def main(args):
raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
else:
if model_type.startswith("phi4"):
args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
args.chat_template = '{system_prompt}<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
elif model_type.startswith("phi"): # For Phi2 and Phi3
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
args.chat_template = '{system_prompt}<|user|>\n{input} <|end|>\n<|assistant|>'
elif model_type.startswith("llama"):
args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
args.chat_template = '{system_prompt}<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
print("Using Chat Template for LLAMA 3, if you are using LLAMA 2 please pass the argument --chat_template '{input} [/INST]')")
elif model_type.startswith("qwen2"):
args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
args.chat_template = '{system_prompt}<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
elif model_type == "gemma3_text":
args.chat_template = '<start_of_turn>user\n{system_prompt}{input}<end_of_turn>\n<start_of_turn>model\n'
else:
raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")

Expand All @@ -70,12 +74,11 @@ def main(args):
print("Using System Prompt for LLAMA 3, if you are using LLAMA 2 please pass the argument --system_prompt '<s>[INST] <<SYS>>\\n{args.system_prompt}\\n<</SYS>>')")
elif model_type.startswith("qwen2"):
system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n"
elif model_type == "gemma3_text":
system_prompt = f"{args.system_prompt}"
else:
system_prompt = args.system_prompt

system_tokens = tokenizer.encode(system_prompt)
system_prompt_length = len(system_tokens)

# Keep asking for input prompts in a loop
while True:
if args.input_prompt:
Expand All @@ -91,7 +94,7 @@ def main(args):

if args.timings: started_timestamp = time.time()

prompt = f'{args.chat_template.format(input=text)}'
prompt = f'{args.chat_template.format(system_prompt=system_prompt, input=text)}'
input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
Expand All @@ -100,7 +103,7 @@ def main(args):
if args.verbose: print("Generator created")

# Append the prompt tokens to the generator (the system prompt is already part of the formatted chat template)
generator.append_tokens(np.concatenate([system_tokens, input_tokens]))
generator.append_tokens(input_tokens)

if args.verbose: print("Running generation loop ...")
if args.timings:
Expand Down
19 changes: 15 additions & 4 deletions examples/python/phi3v.py → examples/python/model-vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import onnxruntime_genai as og


def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name):
curr_path = Path(current_dir).absolute()
target_dir = glob.glob(target_dir_name, root_dir=curr_path)
Expand Down Expand Up @@ -65,16 +64,28 @@ def run(args: argparse.Namespace):

image_paths = [image_path for image_path in image_paths if image_path]

user_tag = image_tag = assistant_tag = end_tag = ""
if model.type == "phi3v":
user_tag = "<|user|>\n"
image_tag = "<|image_{image_id}|>\n"
assistant_tag = "<|assistant|>\n"
end_tag = "<|end|>\n"
elif model.type == "gemma3":
user_tag = "<start_of_turn>user\n"
image_tag = "<start_of_image>"
assistant_tag = "<start_of_turn>model\n"
end_tag = "<end_of_turn>\n"

images = None
prompt = "<|user|>\n"
prompt = f"{user_tag}"
if len(image_paths) == 0:
print("No image provided")
else:
for i, image_path in enumerate(image_paths):
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
print(f"Using image: {image_path}")
prompt += f"<|image_{i+1}|>\n"
prompt += f"{image_tag.replace('{image_id}', str(i+1))}"

images = og.Images.open(*image_paths)

Expand All @@ -85,7 +96,7 @@ def run(args: argparse.Namespace):
text = args.prompt
else:
text = "What is shown in this image?"
prompt += f"{text}<|end|>\n<|assistant|>\n"
prompt += f"{text}{end_tag}{assistant_tag}"
print("Processing images and prompt...")
inputs = processor(prompt, images=images)

Expand Down
1 change: 1 addition & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct Config {
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view promptTemplate = "{Content}";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
static constexpr std::string_view TokenTypeIdsName = "token_type_ids";

// Vision names
static constexpr std::string_view PixelValuesName = "pixel_values";
Expand Down
125 changes: 125 additions & 0 deletions src/models/gemma_image_processor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "../generators.h"
#include "model.h"

#include <regex>

namespace Generators {

namespace {

// Builds the text-side tensors for a prompt that may reference images.
//
// Every "<start_of_image>" marker in the prompt is expanded to
//   "\n\n<start_of_image>" + 256 x "<image_soft_token>" + "<end_of_image>\n\n"
// before tokenization. When the prompt is empty, one marker per image is
// synthesized (separated by single spaces).
//
// Parameters:
//   tokenizer    - used to encode the expanded prompt and to look up the image token id.
//   prompt       - raw user prompt; may be empty.
//   pixel_values - pre-processed image tensor, or nullptr when there are no images.
//                  Its first dimension is read as the number of images.
//   allocator    - allocator used for the returned tensors.
//
// Returns {input_ids, token_type_ids, num_img_tokens}:
//   input_ids      - int32 tensor of shape [1, sequence_length].
//   token_type_ids - int32 tensor of shape [1, sequence_length]; 1 where the token is
//                    the image soft token, 0 elsewhere.
//   num_img_tokens - int32 tensor of shape [1] holding the per-image token count (256).
//
// Throws std::runtime_error when the number of image markers in the prompt differs
// from the number of images, or when pixel_values has no dimensions.
std::tuple<std::unique_ptr<OrtValue>, std::unique_ptr<OrtValue>, std::unique_ptr<OrtValue>>
ProcessImagePrompt(const Generators::Tokenizer& tokenizer, const std::string& prompt,
                   OrtxTensor* pixel_values, Ort::Allocator& allocator) {
  constexpr char boi_token[] = "<start_of_image>";
  constexpr char image_token[] = "<image_soft_token>";
  constexpr char eoi_token[] = "<end_of_image>";
  constexpr size_t image_seq_length = 256;

  int64_t num_images{};
  if (pixel_values) {
    const float* pixel_values_data{};
    const int64_t* pixel_values_shape{};
    size_t pixel_values_num_dims{};
    CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast<const void**>(&pixel_values_data),
                                  &pixel_values_shape, &pixel_values_num_dims));
    // Guard the shape[0] read below: a rank-0 tensor has no leading dimension.
    if (pixel_values_shape == nullptr || pixel_values_num_dims == 0) {
      throw std::runtime_error("pixel_values tensor must have at least one dimension.");
    }
    num_images = pixel_values_shape[0];
  }

  // Generate input_ids and token_type_ids
  std::string text = prompt;
  if (text.empty()) {
    // Synthesize one marker per image, space separated. Building the separator
    // up front (instead of appending a trailing space and popping it) keeps the
    // zero-image case well defined: pop_back() on an empty string is undefined.
    for (int64_t i = 0; i < num_images; ++i) {
      if (i > 0) {
        text += ' ';
      }
      text += boi_token;
    }
  }

  // Count the number of boi tokens and make sure it matches the number of images
  const std::regex boi_regex{std::string(boi_token)};
  const auto boi_begin = std::sregex_iterator(text.begin(), text.end(), boi_regex);
  const auto boi_end = std::sregex_iterator();
  const auto boi_tokens = std::distance(boi_begin, boi_end);
  if (num_images != boi_tokens) {
    throw std::runtime_error("Prompt contained " + std::to_string(boi_tokens) + " image tokens but received " +
                             std::to_string(num_images) + " images.");
  }

  // Expand each marker into the full per-image token sequence.
  std::string image_tokens_expanded{};
  image_tokens_expanded.reserve(image_seq_length * (sizeof(image_token) - 1));
  for (size_t i = 0; i < image_seq_length; ++i) {
    image_tokens_expanded += image_token;
  }
  const std::string full_image_sequence = std::string("\n\n") + boi_token + image_tokens_expanded + eoi_token + std::string("\n\n");

  text = std::regex_replace(text, boi_regex, full_image_sequence);

  const std::vector<int32_t> input_ids = tokenizer.Encode(text.c_str());
  const std::vector<int64_t> ids_shape{1, static_cast<int64_t>(input_ids.size())};

  std::unique_ptr<OrtValue> input_ids_value = OrtValue::CreateTensor<int32_t>(allocator, ids_shape);
  std::copy(input_ids.begin(), input_ids.end(), input_ids_value->GetTensorMutableData<int32_t>());

  // token_type_ids marks the positions occupied by image soft tokens.
  std::unique_ptr<OrtValue> token_type_ids = OrtValue::CreateTensor<int32_t>(allocator, ids_shape);
  const auto image_token_id = tokenizer.TokenToTokenId(image_token);
  int32_t* token_type_data = token_type_ids->GetTensorMutableData<int32_t>();
  for (size_t i = 0; i < input_ids.size(); ++i) {
    token_type_data[i] = (input_ids[i] == image_token_id) ? 1 : 0;
  }

  std::unique_ptr<OrtValue> num_img_tokens = OrtValue::CreateTensor<int32_t>(allocator, std::vector<int64_t>{1});
  num_img_tokens->GetTensorMutableData<int32_t>()[0] = static_cast<int32_t>(image_seq_length);

  return {std::move(input_ids_value), std::move(token_type_ids), std::move(num_img_tokens)};
}

} // namespace

// Constructs the processor: records the element type the model expects for its
// pixel_values input, creates the ortx image pre-processor from the vision config
// file, and registers the default-name -> model-input-name mappings.
GemmaImageProcessor::GemmaImageProcessor(Config& config, const SessionInfo& session_info)
    : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)} {
  // The vision processor config filename is resolved relative to config.config_path.
  const auto processor_config_path = config.config_path / fs::path(config.model.vision.config_filename);
  const auto processor_config_file = processor_config_path.string();
  CheckResult(OrtxCreateProcessor(processor_.ToBeAssigned(), processor_config_file.c_str()));

  const std::string input_ids_key{Config::Defaults::InputIdsName};
  const std::string pixel_values_key{Config::Defaults::PixelValuesName};
  config.AddMapping(input_ids_key, config.model.embedding.inputs.input_ids);
  config.AddMapping(pixel_values_key, config.model.vision.inputs.pixel_values);
}

std::unique_ptr<NamedTensors> GemmaImageProcessor::Process(const Tokenizer& tokenizer, const Payload& payload) const {
std::string prompt = std::string(payload.prompt);
const Images* images = payload.images;
Ort::Allocator& allocator{Ort::Allocator::GetWithDefaultOptions()};
auto named_tensors = std::make_unique<NamedTensors>();

if (!images) {
[[maybe_unused]] auto [input_ids, token_type_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, nullptr, allocator);
named_tensors->emplace(Config::Defaults::InputIdsName, std::make_shared<Tensor>(std::move(input_ids)));
return named_tensors;
}

ort_extensions::OrtxObjectPtr<OrtxTensorResult> result;
CheckResult(OrtxImagePreProcess(processor_.get(), images->images_.get(), result.ToBeAssigned()));

OrtxTensor* pixel_values = nullptr;
CheckResult(OrtxTensorResultGetAt(result.get(), 0, &pixel_values));

auto [input_ids, token_type_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, pixel_values, allocator);
named_tensors->emplace(std::string(Config::Defaults::InputIdsName), std::make_shared<Tensor>(std::move(input_ids)));
named_tensors->emplace(std::string(Config::Defaults::TokenTypeIdsName), std::make_shared<Tensor>(std::move(token_type_ids)));

if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
named_tensors->emplace(std::string(Config::Defaults::PixelValuesName),
std::make_shared<Tensor>(ProcessTensor<float>(pixel_values, allocator)));
} else {
named_tensors->emplace(std::string(Config::Defaults::PixelValuesName),
std::make_shared<Tensor>(ProcessTensor<Ort::Float16_t>(pixel_values, allocator)));
}

named_tensors->emplace(std::string(Config::Defaults::NumImageTokens), std::make_shared<Tensor>(std::move(num_img_tokens)));

return named_tensors;
}

} // namespace Generators
20 changes: 20 additions & 0 deletions src/models/gemma_image_processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "processor.h"

namespace Generators {

struct GemmaImageProcessor : Processor {
GemmaImageProcessor(Config& config, const SessionInfo& session_info);

virtual std::unique_ptr<NamedTensors> Process(const Tokenizer& tokenizer, const Payload& payload) const override;

private:
ort_extensions::OrtxObjectPtr<OrtxProcessor> processor_;

ONNXTensorElementDataType pixel_values_type_;
};

} // namespace Generators
63 changes: 42 additions & 21 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,20 +279,21 @@ SessionInfo::SessionInfo(OrtSession& session) {

void SessionInfo::Add(OrtSession& session) {
auto input_names = session.GetInputNames();
std::vector<ONNXTensorElementDataType> input_types(input_names.size());
for (size_t i = 0; i < input_types.size(); i++) {
auto input_type = session.GetInputTypeInfo(i)->GetTensorTypeAndShapeInfo().GetElementType();
for (size_t i = 0; i < input_names.size(); i++) {
auto type_info = session.GetInputTypeInfo(i);
auto input_type = type_info->GetTensorTypeAndShapeInfo().GetElementType();
auto found_input = inputs_.find(input_names[i]);
if (found_input != inputs_.end() && found_input->second != input_type)
throw std::runtime_error("Model input type mismatch: " + input_names[i] + " expected " + std::to_string(found_input->second) + " got " + std::to_string(input_type));
inputs_.emplace(std::make_pair(std::move(input_names[i]), input_type));
if (found_input != inputs_.end() && found_input->second->GetTensorTypeAndShapeInfo().GetElementType() != input_type)
throw std::runtime_error("Model input type mismatch: " + input_names[i] + " expected " +
std::to_string(found_input->second->GetTensorTypeAndShapeInfo().GetElementType()) +
" got " + std::to_string(input_type));
inputs_.emplace(std::make_pair(std::move(input_names[i]), std::move(type_info)));
}

auto output_names = session.GetOutputNames();
std::vector<ONNXTensorElementDataType> output_types(output_names.size());
for (size_t i = 0; i < output_types.size(); i++) {
auto output_type = session.GetOutputTypeInfo(i)->GetTensorTypeAndShapeInfo().GetElementType();
outputs_.emplace(std::make_pair(std::move(output_names[i]), output_type));
for (size_t i = 0; i < output_names.size(); i++) {
auto type_info = session.GetOutputTypeInfo(i);
outputs_.emplace(std::make_pair(std::move(output_names[i]), std::move(type_info)));
}
}

Expand All @@ -308,14 +309,14 @@ ONNXTensorElementDataType SessionInfo::GetInputDataType(const std::string& name)
auto result = inputs_.find(name);
if (result == inputs_.end())
throw std::runtime_error("Model input was not found: " + name);
return result->second;
return result->second->GetTensorTypeAndShapeInfo().GetElementType();
}

ONNXTensorElementDataType SessionInfo::GetOutputDataType(const std::string& name) const {
auto result = outputs_.find(name);
if (result == outputs_.end())
throw std::runtime_error("Model output was not found: " + name);
return result->second;
return result->second->GetTensorTypeAndShapeInfo().GetElementType();
}

std::vector<std::string> SessionInfo::GetInputNames() const {
Expand All @@ -326,6 +327,20 @@ std::vector<std::string> SessionInfo::GetInputNames() const {
return names;
}

// Returns the symbolic dimension names of the model input called `name`.
// Throws std::runtime_error if no such input has been registered.
// NOTE(review): the returned pointers presumably refer to storage owned by the
// stored type info - confirm their lifetime before holding on to them.
std::vector<const char*> SessionInfo::GetInputSymbolicShape(const std::string& name) const {
  if (const auto entry = inputs_.find(name); entry != inputs_.end()) {
    return entry->second->GetTensorTypeAndShapeInfo().GetSymbolicDimensions();
  }
  throw std::runtime_error("Model input was not found: " + name);
}

// Returns the symbolic dimension names of the model output called `name`.
// Throws std::runtime_error if no such output has been registered.
// NOTE(review): the returned pointers presumably refer to storage owned by the
// stored type info - confirm their lifetime before holding on to them.
std::vector<const char*> SessionInfo::GetOutputSymbolicShape(const std::string& name) const {
  if (const auto entry = outputs_.find(name); entry != outputs_.end()) {
    return entry->second->GetTensorTypeAndShapeInfo().GetSymbolicDimensions();
  }
  throw std::runtime_error("Model output was not found: " + name);
}

Model::Model(std::unique_ptr<Config> config) : config_{std::move(config)} {
CreateSessionOptions();
}
Expand Down Expand Up @@ -568,7 +583,9 @@ std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, con
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, std::unique_ptr<Config> config) {
std::set<std::string> llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2"};
std::set<std::string> llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text",
"granite", "llama", "mistral", "nemotron", "olmo",
"phi", "phimoe", "phi3", "phi3small", "qwen2"};
if (config->model.type == "gpt2")
return std::make_shared<Gpt_Model>(std::move(config), ort_env);
if (llm_types.find(config->model.type) != llm_types.end())
Expand All @@ -581,6 +598,8 @@ std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, std::unique_ptr<Config> conf
return std::make_shared<DecoderOnlyPipelineModel>(std::move(config), ort_env);
if (config->model.type == "phi4mm")
return std::make_shared<MultiModalLanguageModel>(std::move(config), ort_env, true, true);
if (config->model.type == "gemma3")
return std::make_shared<MultiModalLanguageModel>(std::move(config), ort_env, true, false);

throw std::runtime_error("Unsupported model_type in config.json: " + config->model.type);
}
Expand Down Expand Up @@ -651,15 +670,17 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
}

MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& session_info)
: tokenizer_{std::make_shared<Tokenizer>(config)} {
if (config.model.type == "phi3v") {
processor_ = std::make_shared<PhiImageProcessor>(config, session_info);
} else if (config.model.type == "whisper") {
processor_ = std::make_shared<WhisperProcessor>(config, session_info);
} else if (config.model.type == "phi4mm") {
processor_ = std::make_shared<PhiMultiModalProcessor>(config, session_info);
: tokenizer_{std::make_shared<Tokenizer>(config)},
processor_factory_{
{"phi3v", Processor::Create<PhiImageProcessor>},
{"whisper", Processor::Create<WhisperProcessor>},
{"phi4mm", Processor::Create<PhiMultiModalProcessor>},
{"gemma3", Processor::Create<GemmaImageProcessor>}} {
auto processor = processor_factory_.find(config.model.type);
if (processor != processor_factory_.end()) {
processor_ = processor->second(config, session_info);
} else {
throw std::runtime_error("MultiModalProcessor cannot be created. Expected a multimodal model. Actual: " + config.model.type);
throw std::runtime_error("MultiModalProcessor cannot be created. " + config.model.type + " is not a registered multi-modal model type.");
}
}

Expand Down
Loading
Loading