From 8be60d1f923a1310905bb35b4ea26ad918a04409 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 3 Apr 2025 17:45:42 +0000 Subject: [PATCH 1/3] Support for gemma3 model --- examples/python/model-qa.py | 25 ++-- examples/python/{phi3v.py => model-vision.py} | 19 ++- src/config.h | 1 + src/models/gemma_image_processor.cpp | 125 ++++++++++++++++++ src/models/gemma_image_processor.h | 20 +++ src/models/model.cpp | 63 ++++++--- src/models/model.h | 9 +- src/models/multi_modal.cpp | 43 ++++-- src/models/multi_modal.h | 16 ++- src/models/multi_modal_features.cpp | 16 ++- src/models/multi_modal_features.h | 2 +- src/models/phi_image_processor.h | 6 - src/models/phi_multimodal_processor.h | 3 - src/models/processor.h | 8 ++ src/models/whisper_processor.h | 3 - 15 files changed, 289 insertions(+), 70 deletions(-) rename examples/python/{phi3v.py => model-vision.py} (87%) create mode 100644 src/models/gemma_image_processor.cpp create mode 100644 src/models/gemma_image_processor.h diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py index efa0c188be..bf46bc153b 100644 --- a/examples/python/model-qa.py +++ b/examples/python/model-qa.py @@ -1,7 +1,9 @@ -import onnxruntime_genai as og +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License + +import onnxruntime_genai as og import argparse import time -import numpy as np def main(args): if args.verbose: print("Loading model...") @@ -45,14 +47,16 @@ def main(args): raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'") else: if model_type.startswith("phi4"): - args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>' + args.chat_template = '{system_prompt}<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>' elif model_type.startswith("phi"): # For Phi2 and Phi3 - args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>' + args.chat_template = '{system_prompt}<|user|>\n{input} <|end|>\n<|assistant|>' elif model_type.startswith("llama"): - args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>' + args.chat_template = '{system_prompt}<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>' print("Using Chat Template for LLAMA 3, if you are using LLAMA 2 please pass the argument --chat_template '{input} [/INST]')") elif model_type.startswith("qwen2"): - args.chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n' + args.chat_template = '{system_prompt}<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n' + elif model_type == "gemma3_text": + args.chat_template = 'user\n{system_prompt}{input}\nmodel\n' else: raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template") @@ -70,12 +74,11 @@ def main(args): print("Using System Prompt for LLAMA 3, if you are using LLAMA 2 please pass the argument --system_prompt '[INST] <>\\n{args.system_prompt}\\n<>')") elif model_type.startswith("qwen2"): system_prompt = f"<|im_start|>system\n{args.system_prompt}<|im_end|>\n" + elif model_type == "gemma3_text": + system_prompt = f"{args.system_prompt}" else: system_prompt = args.system_prompt - system_tokens = tokenizer.encode(system_prompt) - system_prompt_length = len(system_tokens) - # Keep asking for input prompts in a loop while True: if args.input_prompt: @@ -91,7 +94,7 @@ def main(args): if args.timings: started_timestamp = time.time() - prompt = f'{args.chat_template.format(input=text)}' + prompt = f'{args.chat_template.format(system_prompt=system_prompt, input=text)}' input_tokens = tokenizer.encode(prompt) params = og.GeneratorParams(model) @@ -100,7 +103,7 @@ def main(args): if args.verbose: print("Generator created") # Append system and input tokens to the generator - generator.append_tokens(np.concatenate([system_tokens, input_tokens])) + generator.append_tokens(input_tokens) if args.verbose: print("Running generation loop ...") if args.timings: diff --git a/examples/python/phi3v.py b/examples/python/model-vision.py similarity index 87% rename from examples/python/phi3v.py rename to examples/python/model-vision.py index 18e264426c..a784cbff4e 100644 --- a/examples/python/phi3v.py +++ b/examples/python/model-vision.py @@ -9,7 +9,6 @@ import onnxruntime_genai as og - def _find_dir_contains_sub_dir(current_dir: Path, target_dir_name): curr_path = Path(current_dir).absolute() target_dir = glob.glob(target_dir_name, root_dir=curr_path) @@ -65,8 +64,20 @@ def run(args: argparse.Namespace): image_paths = [image_path for image_path in image_paths if image_path] + user_tag = image_tag = assistant_tag = end_tag = "" + if model.type == "phi3v": + user_tag = "<|user|>\n" + image_tag = "<|image_{image_id}|>\n" + assistant_tag = "<|assistant|>\n" + end_tag = "<|end|>\n" + elif model.type == "gemma3": + user_tag = "user\n" + image_tag = "" + assistant_tag = "model\n" + end_tag = "\n" + images = None - prompt = "<|user|>\n" + prompt = f"{user_tag}" if len(image_paths) == 0: print("No image provided") else: @@ -74,7 +85,7 @@ def run(args: argparse.Namespace): if not os.path.exists(image_path): raise FileNotFoundError(f"Image file not found: {image_path}") print(f"Using image: {image_path}") - prompt += f"<|image_{i+1}|>\n" + prompt += f"{image_tag.replace('{image_id}', str(i+1))}" images = og.Images.open(*image_paths) @@ -85,7 +96,7 @@ def run(args: argparse.Namespace): text = args.prompt else: text = "What is shown in this image?" - prompt += f"{text}<|end|>\n<|assistant|>\n" + prompt += f"{text}{end_tag}{assistant_tag}" print("Processing images and prompt...") inputs = processor(prompt, images=images) diff --git a/src/config.h b/src/config.h index 5a7998ec86..7f30320718 100644 --- a/src/config.h +++ b/src/config.h @@ -26,6 +26,7 @@ struct Config { static constexpr std::string_view PastSequenceLengthName = "past_sequence_length"; static constexpr std::string_view promptTemplate = "{Content}"; static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length"; + static constexpr std::string_view TokenTypeIdsName = "token_type_ids"; // Vision names static constexpr std::string_view PixelValuesName = "pixel_values"; diff --git a/src/models/gemma_image_processor.cpp b/src/models/gemma_image_processor.cpp new file mode 100644 index 0000000000..f51875c72c --- /dev/null +++ b/src/models/gemma_image_processor.cpp @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "../generators.h" +#include "model.h" + +#include + +namespace Generators { + +namespace { + +std::tuple, std::unique_ptr, std::unique_ptr> +ProcessImagePrompt(const Generators::Tokenizer& tokenizer, const std::string& prompt, + OrtxTensor* pixel_values, Ort::Allocator& allocator) { + constexpr char boi_token[] = ""; + constexpr char image_token[] = ""; + constexpr char eoi_token[] = ""; + constexpr size_t image_seq_length = 256; + + int64_t num_images{}; + if (pixel_values) { + const float* pixel_values_data{}; + const int64_t* pixel_values_shape{}; + size_t pixel_values_num_dims; + CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast(&pixel_values_data), + &pixel_values_shape, &pixel_values_num_dims)); + num_images = pixel_values_shape[0]; + } + + // Generate input_ids and token_type_ids + std::string text = prompt; + if (text.empty()) { + for (int64_t i = 0; i < num_images; ++i) { + text += " "; + } + text.pop_back(); + } + + // Count the number of boi tokens and make sure it matches the number of images + const std::regex boi_regex{std::string(boi_token)}; + const auto boi_begin = std::sregex_iterator(text.begin(), text.end(), boi_regex); + const auto boi_end = std::sregex_iterator(); + const auto boi_tokens = std::distance(boi_begin, boi_end); + if (num_images != boi_tokens) { + throw std::runtime_error("Prompt contained " + std::to_string(boi_tokens) + " image tokens but received " + + std::to_string(num_images) + " images."); + } + + std::string image_tokens_expanded{}; + for (size_t i = 0; i < image_seq_length; ++i) { + image_tokens_expanded += image_token; + } + const std::string full_image_sequence = std::string("\n\n") + boi_token + image_tokens_expanded + eoi_token + std::string("\n\n"); + + text = std::regex_replace(text, boi_regex, full_image_sequence); + + const std::vector input_ids = tokenizer.Encode(text.c_str()); + + std::unique_ptr input_ids_value = OrtValue::CreateTensor(allocator, std::vector{1, static_cast(input_ids.size())}); + std::copy(input_ids.begin(), input_ids.end(), input_ids_value->GetTensorMutableData()); + + std::unique_ptr token_type_ids = OrtValue::CreateTensor(allocator, std::vector{1, static_cast(input_ids.size())}); + const auto image_token_id = tokenizer.TokenToTokenId(image_token); + for (size_t i = 0; i < input_ids.size(); ++i) { + if (input_ids[i] == image_token_id) { + token_type_ids->GetTensorMutableData()[i] = 1; + } else { + token_type_ids->GetTensorMutableData()[i] = 0; + } + } + + std::unique_ptr num_img_tokens = OrtValue::CreateTensor(allocator, std::vector{1}); + num_img_tokens->GetTensorMutableData()[0] = static_cast(image_seq_length); + + return {std::move(input_ids_value), std::move(token_type_ids), std::move(num_img_tokens)}; +} + +} // namespace + +GemmaImageProcessor::GemmaImageProcessor(Config& config, const SessionInfo& session_info) + : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)} { + const auto processor_config = (config.config_path / fs::path(config.model.vision.config_filename)).string(); + CheckResult(OrtxCreateProcessor(processor_.ToBeAssigned(), processor_config.c_str())); + + config.AddMapping(std::string(Config::Defaults::InputIdsName), config.model.embedding.inputs.input_ids); + config.AddMapping(std::string(Config::Defaults::PixelValuesName), config.model.vision.inputs.pixel_values); +} + +std::unique_ptr GemmaImageProcessor::Process(const Tokenizer& tokenizer, const Payload& payload) const { + std::string prompt = std::string(payload.prompt); + const Images* images = payload.images; + Ort::Allocator& allocator{Ort::Allocator::GetWithDefaultOptions()}; + auto named_tensors = std::make_unique(); + + if (!images) { + [[maybe_unused]] auto [input_ids, token_type_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, nullptr, allocator); + named_tensors->emplace(Config::Defaults::InputIdsName, std::make_shared(std::move(input_ids))); + return named_tensors; + } + + ort_extensions::OrtxObjectPtr result; + CheckResult(OrtxImagePreProcess(processor_.get(), images->images_.get(), result.ToBeAssigned())); + + OrtxTensor* pixel_values = nullptr; + CheckResult(OrtxTensorResultGetAt(result.get(), 0, &pixel_values)); + + auto [input_ids, token_type_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, pixel_values, allocator); + named_tensors->emplace(std::string(Config::Defaults::InputIdsName), std::make_shared(std::move(input_ids))); + named_tensors->emplace(std::string(Config::Defaults::TokenTypeIdsName), std::make_shared(std::move(token_type_ids))); + + if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(ProcessTensor(pixel_values, allocator))); + } else { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(ProcessTensor(pixel_values, allocator))); + } + + named_tensors->emplace(std::string(Config::Defaults::NumImageTokens), std::make_shared(std::move(num_img_tokens))); + + return named_tensors; +} + +} // namespace Generators diff --git a/src/models/gemma_image_processor.h b/src/models/gemma_image_processor.h new file mode 100644 index 0000000000..a6b951e242 --- /dev/null +++ b/src/models/gemma_image_processor.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "processor.h" + +namespace Generators { + +struct GemmaImageProcessor : Processor { + GemmaImageProcessor(Config& config, const SessionInfo& session_info); + + virtual std::unique_ptr Process(const Tokenizer& tokenizer, const Payload& payload) const override; + + private: + ort_extensions::OrtxObjectPtr processor_; + + ONNXTensorElementDataType pixel_values_type_; +}; + +} // namespace Generators diff --git a/src/models/model.cpp b/src/models/model.cpp index 6b0ecc4963..128a583b4c 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -279,20 +279,21 @@ SessionInfo::SessionInfo(OrtSession& session) { void SessionInfo::Add(OrtSession& session) { auto input_names = session.GetInputNames(); - std::vector input_types(input_names.size()); - for (size_t i = 0; i < input_types.size(); i++) { - auto input_type = session.GetInputTypeInfo(i)->GetTensorTypeAndShapeInfo().GetElementType(); + for (size_t i = 0; i < input_names.size(); i++) { + auto type_info = session.GetInputTypeInfo(i); + auto input_type = type_info->GetTensorTypeAndShapeInfo().GetElementType(); auto found_input = inputs_.find(input_names[i]); - if (found_input != inputs_.end() && found_input->second != input_type) - throw std::runtime_error("Model input type mismatch: " + input_names[i] + " expected " + std::to_string(found_input->second) + " got " + std::to_string(input_type)); - inputs_.emplace(std::make_pair(std::move(input_names[i]), input_type)); + if (found_input != inputs_.end() && found_input->second->GetTensorTypeAndShapeInfo().GetElementType() != input_type) + throw std::runtime_error("Model input type mismatch: " + input_names[i] + " expected " + + std::to_string(found_input->second->GetTensorTypeAndShapeInfo().GetElementType()) + + " got " + std::to_string(input_type)); + inputs_.emplace(std::make_pair(std::move(input_names[i]), std::move(type_info))); } auto output_names = session.GetOutputNames(); - std::vector output_types(output_names.size()); - for (size_t i = 0; i < output_types.size(); i++) { - auto output_type = session.GetOutputTypeInfo(i)->GetTensorTypeAndShapeInfo().GetElementType(); - outputs_.emplace(std::make_pair(std::move(output_names[i]), output_type)); + for (size_t i = 0; i < output_names.size(); i++) { + auto type_info = session.GetOutputTypeInfo(i); + outputs_.emplace(std::make_pair(std::move(output_names[i]), std::move(type_info))); } } @@ -308,14 +309,14 @@ ONNXTensorElementDataType SessionInfo::GetInputDataType(const std::string& name) auto result = inputs_.find(name); if (result == inputs_.end()) throw std::runtime_error("Model input was not found: " + name); - return result->second; + return result->second->GetTensorTypeAndShapeInfo().GetElementType(); } ONNXTensorElementDataType SessionInfo::GetOutputDataType(const std::string& name) const { auto result = outputs_.find(name); if (result == outputs_.end()) throw std::runtime_error("Model output was not found: " + name); - return result->second; + return result->second->GetTensorTypeAndShapeInfo().GetElementType(); } std::vector SessionInfo::GetInputNames() const { @@ -326,6 +327,20 @@ std::vector SessionInfo::GetInputNames() const { return names; } +std::vector SessionInfo::GetInputSymbolicShape(const std::string& name) const { + auto type_info = inputs_.find(name); + if (type_info == inputs_.end()) + throw std::runtime_error("Model input was not found: " + name); + return type_info->second->GetTensorTypeAndShapeInfo().GetSymbolicDimensions(); +} + +std::vector SessionInfo::GetOutputSymbolicShape(const std::string& name) const { + auto type_info = outputs_.find(name); + if (type_info == outputs_.end()) + throw std::runtime_error("Model output was not found: " + name); + return type_info->second->GetTensorTypeAndShapeInfo().GetSymbolicDimensions(); +} + Model::Model(std::unique_ptr config) : config_{std::move(config)} { CreateSessionOptions(); } @@ -568,7 +583,9 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path, con } std::shared_ptr CreateModel(OrtEnv& ort_env, std::unique_ptr config) { - std::set llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2"}; + std::set llm_types = {"chatglm", "decoder", "gemma", "gemma2", "granite", + "llama", "mistral", "nemotron", "olmo", "phi", + "phimoe", "phi3", "phi3small", "qwen2", "gemma3_text"}; if (config->model.type == "gpt2") return std::make_shared(std::move(config), ort_env); if (llm_types.find(config->model.type) != llm_types.end()) @@ -581,6 +598,8 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, std::unique_ptr conf return std::make_shared(std::move(config), ort_env); if (config->model.type == "phi4mm") return std::make_shared(std::move(config), ort_env, true, true); + if (config->model.type == "gemma3") + return std::make_shared(std::move(config), ort_env, true, false); throw std::runtime_error("Unsupported model_type in config.json: " + config->model.type); } @@ -651,15 +670,17 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input, } MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& session_info) - : tokenizer_{std::make_shared(config)} { - if (config.model.type == "phi3v") { - processor_ = std::make_shared(config, session_info); - } else if (config.model.type == "whisper") { - processor_ = std::make_shared(config, session_info); - } else if (config.model.type == "phi4mm") { - processor_ = std::make_shared(config, session_info); + : tokenizer_{std::make_shared(config)}, + processor_factory_{ + {"phi3v", Processor::Create}, + {"whisper", Processor::Create}, + {"phi4mm", Processor::Create}, + {"gemma3", Processor::Create}} { + auto processor = processor_factory_.find(config.model.type); + if (processor != processor_factory_.end()) { + processor_ = processor->second(config, session_info); } else { - throw std::runtime_error("MultiModalProcessor cannot be created. Expected a multimodal model. Actual: " + config.model.type); + throw std::runtime_error("MultiModalProcessor cannot be created. " + config.model.type + " is not a registered multi-modal model type."); } } diff --git a/src/models/model.h b/src/models/model.h index 85a9635daa..6ae4cd8d0f 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -7,6 +7,7 @@ #include "phi_image_processor.h" #include "whisper_processor.h" #include "phi_multimodal_processor.h" +#include "gemma_image_processor.h" #include "adapters.h" #include "extra_outputs.h" @@ -99,6 +100,9 @@ struct MultiModalProcessor : std::enable_shared_from_this, std::shared_ptr tokenizer_; std::shared_ptr processor_; + + private: + std::unordered_map(Config&, const SessionInfo&)>> processor_factory_; }; struct SessionInfo { @@ -114,8 +118,11 @@ struct SessionInfo { std::vector GetInputNames() const; + std::vector GetInputSymbolicShape(const std::string& name) const; + std::vector GetOutputSymbolicShape(const std::string& name) const; + private: - std::unordered_map inputs_, outputs_; + std::unordered_map> inputs_, outputs_; }; struct Model : std::enable_shared_from_this, LeakChecked, ExternalRefCounted { diff --git a/src/models/multi_modal.cpp b/src/models/multi_modal.cpp index edb3873c65..143a43ac39 100644 --- a/src/models/multi_modal.cpp +++ b/src/models/multi_modal.cpp @@ -41,6 +41,22 @@ int64_t GetNumAudioTokens(const std::vector& extra_input return 0; } +int64_t GetImageFeatureBatchSize(const std::vector& extra_inputs) { + for (size_t i = 0; i < extra_inputs.size(); ++i) { + if (extra_inputs[i].name == Config::Defaults::PixelValuesName) { + assert(extra_inputs[i].tensor->ort_tensor_); + const auto num_dims = extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape().size(); + if (num_dims < 3) { + return 0; + } + // If image features have rank 3, the batch size is the first dimension + return extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape().front(); + } + } + + return 0; +} + } // namespace MultiModalLanguageModel::MultiModalLanguageModel(std::unique_ptr config, OrtEnv& ort_env, bool vision, bool speech) @@ -82,12 +98,17 @@ std::unique_ptr MultiModalLanguageModel::CreateState(DeviceSpan return std::make_unique(*this, sequence_lengths, params); } -VisionState::VisionState(const MultiModalLanguageModel& model, const GeneratorParams& params, const int64_t num_image_tokens) +VisionState::VisionState(const MultiModalLanguageModel& model, const GeneratorParams& params, + const int64_t num_images, const int64_t num_image_tokens) : State{params, model}, model_{model}, - num_image_tokens_{num_image_tokens} { + num_image_tokens_{num_image_tokens}, + num_images_{num_images} { extra_inputs_.Add(model_.vision_session_->GetInputNames()); - image_features_.Add(); + image_features_ = std::make_unique(*this, MultiModalFeatures::Mode::Output, // Optional model input + model_.config_->model.vision.outputs.image_features, + num_images_, num_image_tokens_); + image_features_->Add(); } DeviceSpan VisionState::Run(int current_length, DeviceSpan& next_tokens, DeviceSpan next_indices) { @@ -108,7 +129,8 @@ DeviceSpan SpeechState::Run(int current_length, DeviceSpan& next return {}; } -EmbeddingState::EmbeddingState(const MultiModalLanguageModel& model, const GeneratorParams& params, const int64_t num_image_tokens, const int64_t num_audio_tokens) +EmbeddingState::EmbeddingState(const MultiModalLanguageModel& model, const GeneratorParams& params, + const int64_t num_images, const int64_t num_image_tokens, const int64_t num_audio_tokens) : State{params, model}, model_{model}, num_image_tokens_{num_image_tokens}, @@ -117,13 +139,13 @@ EmbeddingState::EmbeddingState(const MultiModalLanguageModel& model, const Gener if (model_.vision_session_) { image_features_ = std::make_unique(*this, MultiModalFeatures::Mode::Input, // Optional model input model_.config_->model.embedding.inputs.image_features, - num_image_tokens_); + num_images, num_image_tokens_); image_features_->Add(); } if (model_.speech_session_) { audio_features_ = std::make_unique(*this, MultiModalFeatures::Mode::Input, // Optional model input model_.config_->model.embedding.inputs.audio_features, - num_audio_tokens_); + -1, num_audio_tokens_); audio_features_->Add(); } inputs_embeds_.Add(); @@ -170,14 +192,15 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalLanguageModel& model_{model}, num_image_tokens_{GetNumImageTokens(params_->extra_inputs)}, num_audio_tokens_{GetNumAudioTokens(params_->extra_inputs, model_.config_->model.speech.inputs.audio_sizes)}, - adapters_{std::make_shared(&model_)} { + adapters_{std::make_shared(&model_)}, + num_images_{GetImageFeatureBatchSize(params_->extra_inputs)} { if (model_.vision_session_) { - vision_state_ = std::make_unique(model_, params, num_image_tokens_); + vision_state_ = std::make_unique(model_, params, num_images_, num_image_tokens_); } if (model_.speech_session_) { speech_state_ = std::make_unique(model_, params, num_audio_tokens_); } - embedding_state_ = std::make_unique(model, params, num_image_tokens_, num_audio_tokens_); + embedding_state_ = std::make_unique(model, params, num_images_, num_image_tokens_, num_audio_tokens_); decoder_state_ = std::make_unique(model_, sequence_lengths, params); if (vision_state_ != nullptr && model_.config_->model.vision.adapter_filename.has_value() && num_image_tokens_ > 0) { @@ -214,7 +237,7 @@ DeviceSpan MultiModalPipelineState::Run(int current_length, DeviceSpan 0 && speech_state_) { speech_state_->Run(current_length, next_tokens, next_indices); } - if (vision_state_) embedding_state_->image_features_->ReuseFeaturesBuffer(vision_state_->image_features_); + if (vision_state_) embedding_state_->image_features_->ReuseFeaturesBuffer(*vision_state_->image_features_); if (speech_state_) embedding_state_->audio_features_->ReuseFeaturesBuffer(speech_state_->audio_features_); embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_); embedding_state_->Run(current_length, next_tokens, next_indices); diff --git a/src/models/multi_modal.h b/src/models/multi_modal.h index 9310e66dc1..a45a623cbd 100644 --- a/src/models/multi_modal.h +++ b/src/models/multi_modal.h @@ -27,7 +27,8 @@ struct MultiModalLanguageModel : Model { }; struct VisionState : State { - VisionState(const MultiModalLanguageModel& model, const GeneratorParams& params, const int64_t num_image_tokens); + VisionState(const MultiModalLanguageModel& model, const GeneratorParams& params, + const int64_t num_images, const int64_t num_image_tokens); VisionState(const VisionState&) = delete; VisionState& operator=(const VisionState&) = delete; @@ -38,10 +39,9 @@ struct VisionState : State { const MultiModalLanguageModel& model_; int64_t num_image_tokens_; - ExtraInputs extra_inputs_{*this}; // Model inputs - MultiModalFeatures image_features_{*this, MultiModalFeatures::Mode::Output, // Model output - model_.config_->model.vision.outputs.image_features, - num_image_tokens_}; + int64_t num_images_{}; + ExtraInputs extra_inputs_{*this}; // Model inputs + std::unique_ptr image_features_; }; struct SpeechState : State { @@ -59,11 +59,12 @@ struct SpeechState : State { ExtraInputs extra_inputs_{*this}; // Model inputs MultiModalFeatures audio_features_{*this, MultiModalFeatures::Mode::Output, // Model output model_.config_->model.speech.outputs.audio_features, - num_audio_tokens_}; + -1, num_audio_tokens_}; }; struct EmbeddingState : State { - EmbeddingState(const MultiModalLanguageModel& model, const GeneratorParams& params, const int64_t num_image_tokens, const int64_t num_audio_tokens); + EmbeddingState(const MultiModalLanguageModel& model, const GeneratorParams& params, + const int64_t num_images, const int64_t num_image_tokens, const int64_t num_audio_tokens); EmbeddingState(const EmbeddingState&) = delete; EmbeddingState& operator=(const EmbeddingState&) = delete; @@ -124,6 +125,7 @@ struct MultiModalPipelineState : State { const MultiModalLanguageModel& model_; int64_t num_image_tokens_{}; int64_t num_audio_tokens_{}; + int64_t num_images_{}; std::unique_ptr vision_state_; std::unique_ptr speech_state_; std::unique_ptr embedding_state_; diff --git a/src/models/multi_modal_features.cpp b/src/models/multi_modal_features.cpp index 77f3f5acce..ea43143e8f 100644 --- a/src/models/multi_modal_features.cpp +++ b/src/models/multi_modal_features.cpp @@ -6,13 +6,23 @@ namespace Generators { -MultiModalFeatures::MultiModalFeatures(State& state, MultiModalFeatures::Mode mode, const std::string& name, int64_t num_feature_tokens) +MultiModalFeatures::MultiModalFeatures(State& state, MultiModalFeatures::Mode mode, const std::string& name, + int64_t batch_size, int64_t num_feature_tokens) : state_{state}, type_{mode == MultiModalFeatures::Mode::Input ? model_.session_info_->GetInputDataType(name) : model_.session_info_->GetOutputDataType(name)}, mode_{mode}, name_{name} { + const auto dims = mode_ == MultiModalFeatures::Mode::Input + ? model_.session_info_->GetInputSymbolicShape(name).size() + : model_.session_info_->GetOutputSymbolicShape(name).size(); + + // If the model expects 3 dimensions, add a batch dimension + if (dims == 3) { + shape_.push_back(1); + } + shape_.push_back(num_feature_tokens); shape_.push_back(model_.config_->model.decoder.hidden_size); @@ -50,8 +60,8 @@ void MultiModalFeatures::Add() { void MultiModalFeatures::Update(bool is_prompt) { // Initialize empty features tensor for after-prompt input scenarios // num_feature_tokens will be 0 when no image is provided - if (!is_prompt && shape_[0] > 0) { // if num_image_tokens > 0 - shape_[0] = 0; + if (!is_prompt && shape_[shape_.size() - 2] > 0) { // if num_image_tokens > 0 + shape_[shape_.size() - 2] = 0; features_ = OrtValue::CreateTensor(model_.p_device_->GetAllocator(), shape_, type_); state_.inputs_[index_] = features_.get(); } diff --git a/src/models/multi_modal_features.h b/src/models/multi_modal_features.h index 2ecab0d523..f61c6580f8 100644 --- a/src/models/multi_modal_features.h +++ b/src/models/multi_modal_features.h @@ -11,7 +11,7 @@ struct MultiModalFeatures { Output }; - MultiModalFeatures(State& state, MultiModalFeatures::Mode mode, const std::string& name, int64_t num_feature_tokens); + MultiModalFeatures(State& state, MultiModalFeatures::Mode mode, const std::string& name, int64_t batch_size, int64_t num_feature_tokens); MultiModalFeatures(const MultiModalFeatures&) = delete; MultiModalFeatures& operator=(const MultiModalFeatures&) = delete; diff --git a/src/models/phi_image_processor.h b/src/models/phi_image_processor.h index 6390e30c92..f72a2a447e 100644 --- a/src/models/phi_image_processor.h +++ b/src/models/phi_image_processor.h @@ -6,9 +6,6 @@ namespace Generators { -struct Config; -struct SessionInfo; - struct PhiImageProcessor : Processor { PhiImageProcessor(Config& config, const SessionInfo& session_info); @@ -21,10 +18,7 @@ struct PhiImageProcessor : Processor { private: ort_extensions::OrtxObjectPtr processor_; - std::string input_ids_name_; - std::string pixel_values_name_; ONNXTensorElementDataType pixel_values_type_; - std::string image_sizes_name_; }; } // namespace Generators diff --git a/src/models/phi_multimodal_processor.h b/src/models/phi_multimodal_processor.h index 3826596d89..39b64eb01c 100644 --- a/src/models/phi_multimodal_processor.h +++ b/src/models/phi_multimodal_processor.h @@ -6,9 +6,6 @@ namespace Generators { -struct Config; -struct SessionInfo; - struct PhiMultiModalProcessor : Processor { PhiMultiModalProcessor(Config& config, const SessionInfo& session_info); diff --git a/src/models/processor.h b/src/models/processor.h index 7a621013a9..66b29cfc83 100644 --- a/src/models/processor.h +++ b/src/models/processor.h @@ -44,6 +44,9 @@ struct Payload { const Audios* audios; }; +struct Config; +struct SessionInfo; + template std::unique_ptr ProcessTensor(OrtxTensor* tensor, Ort::Allocator& allocator); @@ -54,6 +57,11 @@ struct Processor { Processor(const Processor&) = delete; Processor& operator=(const Processor&) = delete; + template + static std::shared_ptr Create(Config& config, const SessionInfo& session_info) { + return std::make_shared(config, session_info); + } + virtual std::unique_ptr Process(const Tokenizer& tokenizer, const Payload& payload) const = 0; }; diff --git a/src/models/whisper_processor.h b/src/models/whisper_processor.h index 251ac59e24..335410d099 100644 --- a/src/models/whisper_processor.h +++ b/src/models/whisper_processor.h @@ -6,9 +6,6 @@ namespace Generators { -struct Config; -struct SessionInfo; - struct WhisperProcessor : Processor { WhisperProcessor(Config& config, const SessionInfo& session_info); From e7d61388e5f0b68eaa24d96868e6663ab337ee65 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 3 Apr 2025 18:13:10 +0000 Subject: [PATCH 2/3] Address pipeline failures --- src/models/multi_modal.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/multi_modal.cpp b/src/models/multi_modal.cpp index 143a43ac39..ca4baa17ec 100644 --- a/src/models/multi_modal.cpp +++ b/src/models/multi_modal.cpp @@ -192,8 +192,8 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalLanguageModel& model_{model}, num_image_tokens_{GetNumImageTokens(params_->extra_inputs)}, num_audio_tokens_{GetNumAudioTokens(params_->extra_inputs, model_.config_->model.speech.inputs.audio_sizes)}, - adapters_{std::make_shared(&model_)}, - num_images_{GetImageFeatureBatchSize(params_->extra_inputs)} { + num_images_{GetImageFeatureBatchSize(params_->extra_inputs)}, + adapters_{std::make_shared(&model_)} { if (model_.vision_session_) { vision_state_ = std::make_unique(model_, params, num_images_, num_image_tokens_); } From 329867fe30a7f70316a9208b4247d731109ec4a8 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 7 Apr 2025 17:49:26 +0000 Subject: [PATCH 3/3] Address pull-request review feedback --- src/models/model.cpp | 6 +++--- src/models/multi_modal_features.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index 128a583b4c..c994be2c1f 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -583,9 +583,9 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path, con } std::shared_ptr CreateModel(OrtEnv& ort_env, std::unique_ptr config) { - std::set llm_types = {"chatglm", "decoder", "gemma", "gemma2", "granite", - "llama", "mistral", "nemotron", "olmo", "phi", - "phimoe", "phi3", "phi3small", "qwen2", "gemma3_text"}; + std::set llm_types = {"chatglm", "decoder", "gemma", "gemma2", "gemma3_text", + "granite", "llama", "mistral", "nemotron", "olmo", + "phi", "phimoe", "phi3", "phi3small", "qwen2"}; if (config->model.type == "gpt2") return std::make_shared(std::move(config), ort_env); if (llm_types.find(config->model.type) != llm_types.end()) diff --git a/src/models/multi_modal_features.cpp b/src/models/multi_modal_features.cpp index ea43143e8f..f6f983166a 100644 --- a/src/models/multi_modal_features.cpp +++ b/src/models/multi_modal_features.cpp @@ -20,7 +20,7 @@ MultiModalFeatures::MultiModalFeatures(State& state, MultiModalFeatures::Mode mo // If the model expects 3 dimensions, add a batch dimension if (dims == 3) { - shape_.push_back(1); + shape_.push_back(batch_size); } shape_.push_back(num_feature_tokens);