diff --git a/cmake/deps.txt b/cmake/deps.txt index 7e50996352..fc0530b187 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e -onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;245f6667babf9668b862ac4513c69ea95117c295 +onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;301b442d8f903daba129e825cd446755b840abb0 # These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE) llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d diff --git a/src/config.cpp b/src/config.cpp index 7087819d86..3e1c535780 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -615,6 +615,8 @@ struct VisionInputs_Element : JSON::Element { v_.pixel_values = JSON::Get(value); } else if (name == "image_sizes") { v_.image_sizes = JSON::Get(value); + } else if (name == "image_grid_thw") { + v_.image_grid_thw = JSON::Get(value); } else if (name == "attention_mask") { v_.attention_mask = JSON::Get(value); } else { @@ -651,6 +653,10 @@ struct Vision_Element : JSON::Element { v_.config_filename = JSON::Get(value); } else if (name == "adapter_filename") { v_.adapter_filename = JSON::Get(value); + } else if (name == "spatial_merge_size") { + v_.spatial_merge_size = static_cast(JSON::Get(value)); + } else if (name == "tokens_per_second") { + v_.tokens_per_second = static_cast(JSON::Get(value)); } else { throw JSON::unknown_value_error{}; } @@ -856,6 +862,12 @@ struct Model_Element : JSON::Element { v_.decoder_start_token_id = static_cast(JSON::Get(value)); } else if (name == "sep_token_id") { v_.sep_token_id = static_cast(JSON::Get(value)); + } else if (name == "image_token_id") { + v_.image_token_id = static_cast(JSON::Get(value)); + } else if (name == "video_token_id") { + v_.video_token_id = static_cast(JSON::Get(value)); + } else if (name == "vision_start_token_id") { + v_.vision_start_token_id = static_cast(JSON::Get(value)); } else { throw JSON::unknown_value_error{}; } diff --git a/src/config.h b/src/config.h index 507d7c80c1..c03a1d1860 100644 --- a/src/config.h +++ b/src/config.h @@ -38,6 +38,7 @@ struct Config { // Vision encoder names static constexpr std::string_view PixelValuesName = "pixel_values"; static constexpr std::string_view ImageSizesName = "image_sizes"; + static constexpr std::string_view ImageGridThwName = "image_grid_thw"; static constexpr std::string_view ImageAttentionMaskName = "image_attention_mask"; static constexpr std::string_view ImageFeaturesName = "image_features"; static constexpr std::string_view NumImageTokens = "num_image_tokens"; @@ -106,6 +107,12 @@ struct Config { int bos_token_id{}; // The id of the beginning-of-stream token. int sep_token_id{}; // The id of the separation token. int decoder_start_token_id{}; // If an encoder-decoder model starts decoding with a different token than bos, the id of that token. + + // Qwen2-VL specific token IDs + int image_token_id{}; + int video_token_id{}; + int vision_start_token_id{}; + int vocab_size{}; int context_length{}; @@ -159,9 +166,14 @@ struct Config { std::string config_filename{"processor_config.json"}; std::optional adapter_filename{}; + // Qwen2-VL specific vision config values + int spatial_merge_size{2}; + float tokens_per_second{2.0f}; + struct Inputs { std::string pixel_values{Defaults::PixelValuesName}; std::string image_sizes{Defaults::ImageSizesName}; + std::string image_grid_thw{Defaults::ImageGridThwName}; std::string attention_mask{Defaults::ImageAttentionMaskName}; // image attention mask } inputs; diff --git a/src/models/model.cpp b/src/models/model.cpp index 77e6c82657..5aec0ed415 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1288,7 +1288,9 @@ MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& sess {"phi3v", Processor::Create}, {"whisper", Processor::Create}, {"phi4mm", Processor::Create}, - {"gemma3", Processor::Create}} { + {"gemma3", Processor::Create}, + {"qwen2vl", Processor::Create}, + {"qwen2_5_vl", Processor::Create}} { auto processor = processor_factory_.find(config.model.type); if (processor != processor_factory_.end()) { processor_ = processor->second(config, session_info); diff --git a/src/models/model.h b/src/models/model.h index 7faa1000fe..62034bb1a3 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -9,6 +9,7 @@ #include "whisper_processor.h" #include "phi_multimodal_processor.h" #include "gemma_image_processor.h" +#include "qwen_image_processor.h" #include "adapters.h" #include "extra_outputs.h" @@ -176,4 +177,4 @@ struct Model : std::enable_shared_from_this, LeakChecked, External std::map> pipeline_session_options_; }; -} // namespace Generators +} // namespace Generators \ No newline at end of file diff --git a/src/models/model_type.h b/src/models/model_type.h index c7c4d2f691..0094ab2a11 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -18,10 +18,15 @@ struct ModelType { inline static bool IsVLM(const std::string& model_type) { // Vision-language model (VLM) - static constexpr std::array VLM = {"gemma3", "phi3v"}; + static constexpr std::array VLM = {"gemma3", "phi3v", "qwen2vl", "qwen2_5_vl"}; return std::find(VLM.begin(), VLM.end(), model_type) != VLM.end(); } + inline static bool IsQwen2VL(const std::string& model_type) { + // Qwen2-VL specific check for 3D position IDs + return model_type == "qwen2vl" || model_type == "qwen2_5_vl"; + } + inline static bool IsALM(const std::string& model_type) { // Audio-language model (ALM) static constexpr std::array ALM = {"whisper"}; diff --git a/src/models/multi_modal.cpp b/src/models/multi_modal.cpp index 56e55b1552..5abb9f4c2a 100644 --- a/src/models/multi_modal.cpp +++ b/src/models/multi_modal.cpp @@ -3,6 +3,7 @@ #include "../generators.h" #include "multi_modal.h" +#include namespace Generators { @@ -178,12 +179,13 @@ DeviceSpan EmbeddingState::Run(int current_length, DeviceSpan& n return {}; } -DecoderState::DecoderState(const MultiModalLanguageModel& model, DeviceSpan sequence_lengths, const GeneratorParams& params) +DecoderState::DecoderState(const MultiModalLanguageModel& model, DeviceSpan sequence_lengths, + const GeneratorParams& params) : State{params, model}, model_{model}, - position_inputs_{model, *this, sequence_lengths, model_.config_->model.decoder.inputs.attention_mask} { + position_inputs_{CreatePositionInputs(*this, sequence_lengths, model_.config_->model.decoder.inputs.attention_mask)} { inputs_embeds_.Add(); - position_inputs_.Add(); + position_inputs_->Add(); logits_.Add(); kv_cache_.Add(); } @@ -201,7 +203,14 @@ DeviceSpan DecoderState::Run(int current_length, DeviceSpan& nex void DecoderState::UpdateInputsOutputs(DeviceSpan& next_tokens, int total_length, DeviceSpan beam_indices) { int batch_size = static_cast(inputs_embeds_.GetShape()[0]); size_t new_length = next_tokens.size() / batch_size; - position_inputs_.Update(next_tokens, total_length, static_cast(new_length)); + position_inputs_->Update(next_tokens, total_length, static_cast(new_length)); + kv_cache_.Update(beam_indices, total_length); + logits_.Update(next_tokens, new_length); + inputs_embeds_.UpdateSequenceLength(new_length); +} + +// Overload for pipeline to call +void DecoderState::UpdateInputsOutputs(DeviceSpan& next_tokens, int total_length, DeviceSpan beam_indices, size_t new_length) { kv_cache_.Update(beam_indices, total_length); logits_.Update(next_tokens, new_length); inputs_embeds_.UpdateSequenceLength(new_length); @@ -243,6 +252,25 @@ void MultiModalPipelineState::SetExtraInputs(const std::vector& extr speech_state_->SetExtraInputs(extra_inputs, num_audio_tokens_); } embedding_state_->SetExtraInputs(num_images_, num_image_tokens_, num_audio_tokens_); + + // Set the grid tensors for Qwen2-VL if present + if (auto* qwen_pos_inputs = dynamic_cast(decoder_state_->position_inputs_.get())) { + std::shared_ptr img_grid, vid_grid, sec_grid; + + for (const auto& input : extra_inputs) { + if (input.name == Config::Defaults::ImageGridThwName) { + img_grid = input.tensor; + } else if (input.name == "video_grid_thw") { + vid_grid = input.tensor; + } else if (input.name == "second_per_grid_ts") { + sec_grid = input.tensor; + } + } + + if (img_grid || vid_grid) { + qwen_pos_inputs->SetGridTensors(img_grid, vid_grid, sec_grid); + } + } } DeviceSpan MultiModalPipelineState::Run(int current_length, DeviceSpan& next_tokens, DeviceSpan next_indices) { @@ -357,4 +385,4 @@ OrtValue* MultiModalPipelineState::GetOutput(const char* name) { return State::GetOutput(name); }; -} // namespace Generators +} // namespace Generators \ No newline at end of file diff --git a/src/models/multi_modal.h b/src/models/multi_modal.h index 206fc3850b..771f5be36d 100644 --- a/src/models/multi_modal.h +++ b/src/models/multi_modal.h @@ -18,7 +18,7 @@ struct MultiModalLanguageModel : Model { MultiModalLanguageModel(const MultiModalLanguageModel&) = delete; MultiModalLanguageModel& operator=(const MultiModalLanguageModel&) = delete; - std::unique_ptr CreateState(DeviceSpan sequence_lengths, const GeneratorParams& params) const; + std::unique_ptr CreateState(DeviceSpan sequence_lengths, const GeneratorParams& params) const override; std::unique_ptr vision_session_; // pixel_values, [image_attention_mask], image_sizes -> image_features std::unique_ptr speech_session_; // audio_embeds, audio_sizes, audio_projection_mode -> audio_features @@ -96,18 +96,19 @@ struct DecoderState : State { DecoderState& operator=(const DecoderState&) = delete; DeviceSpan Run(int current_length, DeviceSpan& next_tokens, DeviceSpan next_indices) override; + void UpdateInputsOutputs(DeviceSpan& next_tokens, int current_length, DeviceSpan beam_indices); private: friend struct MultiModalPipelineState; - void UpdateInputsOutputs(DeviceSpan& next_tokens, int current_length, DeviceSpan beam_indices); + void UpdateInputsOutputs(DeviceSpan& next_tokens, int current_length, DeviceSpan beam_indices, size_t new_length); const MultiModalLanguageModel& model_; Embeddings inputs_embeds_{*this, Embeddings::Mode::Input, // Model input model_.config_->model.decoder.inputs.embeddings}; - DefaultPositionInputs position_inputs_; // Model input - DefaultKeyValueCache kv_cache_{*this}; // Model input - Logits logits_{*this}; // Model output + std::unique_ptr position_inputs_; // Model input + DefaultKeyValueCache kv_cache_{*this}; // Model input + Logits logits_{*this}; // Model output }; struct MultiModalPipelineState : State { @@ -144,4 +145,4 @@ struct MultiModalPipelineState : State { const std::string speech_adapter_name_{"speech"}; }; -} // namespace Generators +} // namespace Generators \ No newline at end of file diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp index d87a8c6b64..4e53b27a08 100644 --- a/src/models/position_inputs.cpp +++ b/src/models/position_inputs.cpp @@ -1,9 +1,22 @@ #include "../generators.h" #include "model.h" #include "position_inputs.h" +#include "model_type.h" +#include +#include +#include // For std::round namespace Generators { +// Helper to dispatch type-specific tensor operations +template +void DispatchOnType(ONNXTensorElementDataType type, Func&& func) { + if (type == Ort::TypeToTensorType) + func.template operator()(); + else + func.template operator()(); +} + DefaultPositionInputs::DefaultPositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk, const std::string& attention_mask_name) : model_{model}, state_{state}, @@ -477,7 +490,384 @@ void WindowedPositionInputs::Update(DeviceSpan next_tokens, int total_l window_index_++; } +// Qwen2VLPositionInputs implementation +Qwen2VLPositionInputs::Qwen2VLPositionInputs(const Model& model, State& state, DeviceSpan sequence_lengths_unk) + : model_{model}, + state_{state}, + image_token_id_{model.config_->model.image_token_id}, + video_token_id_{model.config_->model.video_token_id}, + vision_start_token_id_{model.config_->model.vision_start_token_id}, + tokens_per_second_{model.config_->model.vision.tokens_per_second}, + spatial_merge_size_{model.config_->model.vision.spatial_merge_size} { + has_mask_input_ = model_.session_info_.HasInput(model_.config_->model.decoder.inputs.attention_mask); + has_posid_input_ = model_.session_info_.HasInput(model_.config_->model.decoder.inputs.position_ids); + + type_ = Ort::TypeToTensorType; // Default to int64 for Qwen2VL + if (has_mask_input_) { + type_ = model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.attention_mask); + } + + if (has_posid_input_) { + ONNXTensorElementDataType posid_type = model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.position_ids); + + // Set up 3D position IDs shape: [3, batch_size, sequence_length] + // The 3 dimensions represent temporal, height, and width for mrope + position_ids_shape_[0] = 3; + position_ids_shape_[1] = state_.params_->search.batch_size; + position_ids_shape_[2] = 0; // Will be set during first update + + position_ids_ = std::make_unique(model_.p_device_inputs_, posid_type); + } + if (has_mask_input_) { + attention_mask_shape_[0] = state_.params_->search.batch_size; + attention_mask_shape_[1] = 0; // Will be set during first update + attention_mask_ = std::make_unique(model_.p_device_inputs_, type_); + } +} + +void Qwen2VLPositionInputs::SetGridTensors(const std::shared_ptr& image_grid_thw, + const std::shared_ptr& video_grid_thw, + const std::shared_ptr& second_per_grid_ts) { + image_grid_thw_ = image_grid_thw; + video_grid_thw_ = video_grid_thw; + second_per_grid_ts_ = second_per_grid_ts; +} + +void Qwen2VLPositionInputs::Add() { + if (has_posid_input_) { + AddPositionIDs(); + } + if (has_mask_input_) { + AddAttentionMask(); + } +} + +void Qwen2VLPositionInputs::AddPositionIDs() { + posid_input_index_ = state_.inputs_.size(); + state_.inputs_.push_back(position_ids_->GetOrtTensor()); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.position_ids.c_str()); +} + +void Qwen2VLPositionInputs::AddAttentionMask() { + mask_input_index_ = state_.inputs_.size(); + state_.inputs_.push_back(attention_mask_->GetOrtTensor()); + state_.input_names_.push_back(model_.config_->model.decoder.inputs.attention_mask.c_str()); +} + +template +void Qwen2VLPositionInputs::CreateAndInitialize3DPositionIDs(DeviceSpan next_tokens, std::array shape) { + // Replicates the logic from HuggingFace's `get_rope_index` + // `shape` is [3, batch_size, seq_len] (before beam expansion) + // `next_tokens` is [batch_size, seq_len] + int64_t num_dims = shape[0]; // Should be 3 + int64_t batch_size = shape[1]; + int64_t seq_len = shape[2]; + + auto position_ids = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); + auto* position_data = position_ids->GetTensorMutableData(); + + // Get spans for grid_thw tensors (on CPU) + std::span image_grid_thw_span; + if (image_grid_thw_) { + image_grid_thw_span = std::span(image_grid_thw_->GetData(), image_grid_thw_->GetElementCount()); + } + + std::span video_grid_thw_span; + if (video_grid_thw_) { + video_grid_thw_span = std::span(video_grid_thw_->GetData(), video_grid_thw_->GetElementCount()); + } + + std::span second_per_grid_ts_span; + if (second_per_grid_ts_) { + // Qwen 2.5 processor outputs float32 for this + if (second_per_grid_ts_->GetType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw std::runtime_error("second_per_grid_ts must be float32."); + second_per_grid_ts_span = std::span(second_per_grid_ts_->GetData(), second_per_grid_ts_->GetElementCount()); + } + + auto input_ids_span = next_tokens.CpuSpan(); + int image_index = 0; + int video_index = 0; + rope_deltas_.clear(); + + for (int64_t b = 0; b < batch_size; ++b) { + auto input_ids = input_ids_span.subspan(b * seq_len, seq_len); + + int64_t image_nums = 0; + int64_t video_nums = 0; + + // Count images/videos for this batch item by checking the token *after* vision_start_token_id + for (size_t s = 0; s < seq_len - 1; ++s) { + if (input_ids[s] == vision_start_token_id_) { + if (input_ids[s + 1] == image_token_id_) { + image_nums++; + } else if (input_ids[s + 1] == video_token_id_) { + video_nums++; + } + } + } + + int64_t st = 0; + int64_t remain_images = image_nums; + int64_t remain_videos = video_nums; + T st_idx = 0; + T max_pos_for_batch = 0; + + for (int64_t k = 0; k < image_nums + video_nums; ++k) { + int64_t ed_image = seq_len + 1; + int64_t ed_video = seq_len + 1; + + // Find next image_token_id (after a vision_start_token_id) + if (remain_images > 0) { + for (int64_t s = st; s < seq_len - 1; ++s) { + if (input_ids[s] == vision_start_token_id_ && input_ids[s + 1] == image_token_id_) { + ed_image = s + 1; // Point to the image_token_id + break; + } + } + } + // Find next video_token_id (after a vision_start_token_id) + if (remain_videos > 0) { + for (int64_t s = st; s < seq_len - 1; ++s) { + if (input_ids[s] == vision_start_token_id_ && input_ids[s + 1] == video_token_id_) { + ed_video = s + 1; // Point to the video_token_id + break; + } + } + } + + int64_t ed; + int64_t t, h, w; + float second_per_grid_t = 0.0f; + + if (ed_image < ed_video) { + // Process image + if (image_index * 3 + 2 >= image_grid_thw_span.size()) + throw std::runtime_error("Not enough image_grid_thw data for image tokens."); + t = image_grid_thw_span[image_index * 3 + 0]; + h = image_grid_thw_span[image_index * 3 + 1]; + w = image_grid_thw_span[image_index * 3 + 2]; + second_per_grid_t = 0.0f; // Images have 0 time delta + image_index++; + remain_images--; + ed = ed_image; + } else { + // Process video + if (video_index * 3 + 2 >= video_grid_thw_span.size()) + throw std::runtime_error("Not enough video_grid_thw data for video tokens."); + t = video_grid_thw_span[video_index * 3 + 0]; + h = video_grid_thw_span[video_index * 3 + 1]; + w = video_grid_thw_span[video_index * 3 + 2]; + if (second_per_grid_ts_span.empty() || video_index >= second_per_grid_ts_span.size()) { + second_per_grid_t = 1.0f; // Default from Python + } else { + second_per_grid_t = second_per_grid_ts_span[video_index]; + } + video_index++; + remain_videos--; + ed = ed_video; + } + + int64_t llm_grid_t = t; + int64_t llm_grid_h = h / spatial_merge_size_; + int64_t llm_grid_w = w / spatial_merge_size_; + + // 1. Fill Text Part + // Text runs from `st` up to `ed-1` (which is the <|vision_start|> token) + int64_t text_len = ed - st; + st_idx = (k > 0 || b > 0) ? max_pos_for_batch + 1 : 0; + T current_pos = st_idx; + + for (int64_t s = 0; s < text_len; ++s) { + int64_t current_token_idx = st + s; + if (input_ids[current_token_idx] == model_.config_->model.pad_token_id) { + position_data[0 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + position_data[1 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + position_data[2 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + } else { + position_data[0 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + position_data[1 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + position_data[2 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + max_pos_for_batch = current_pos; + current_pos++; // Only increment position for non-pad tokens + } + } + + // 2. Fill Vision Part + st_idx = max_pos_for_batch + 1; + int64_t vision_len = llm_grid_t * llm_grid_h * llm_grid_w; + for (int64_t s = 0; s < vision_len; ++s) { + int64_t gt = s / (llm_grid_h * llm_grid_w); + int64_t gh = (s / llm_grid_w) % llm_grid_h; + int64_t gw = s % llm_grid_w; + + // Round to nearest integer for temporal position + // Note: huggingface code use truncation/floor (time_tensor_long = time_tensor.long() when converting time coordinates. + // This will cause slight deviation from the reference during parity comparsion. + T t_pos = static_cast(std::round(gt * second_per_grid_t * tokens_per_second_)) + st_idx; + T h_pos = static_cast(gh) + st_idx; + T w_pos = static_cast(gw) + st_idx; + + // Vision tokens are guaranteed not to be padding + position_data[0 * batch_size * seq_len + b * seq_len + (ed + s)] = t_pos; + position_data[1 * batch_size * seq_len + b * seq_len + (ed + s)] = h_pos; + position_data[2 * batch_size * seq_len + b * seq_len + (ed + s)] = w_pos; + max_pos_for_batch = std::max({max_pos_for_batch, t_pos, h_pos, w_pos}); + } + st = ed + vision_len; // New start is after the vision tokens + } + + // 3. Fill Remaining Text Part + if (st < seq_len) { + st_idx = (max_pos_for_batch == 0 && st == 0) ? 0 : max_pos_for_batch + 1; + int64_t text_len = seq_len - st; + T current_pos = st_idx; + for (int64_t s = 0; s < text_len; ++s) { + int64_t current_token_idx = st + s; + if (input_ids[current_token_idx] == model_.config_->model.pad_token_id) { + position_data[0 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + position_data[1 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + position_data[2 * batch_size * seq_len + b * seq_len + current_token_idx] = 0; + } else { + position_data[0 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + position_data[1 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + position_data[2 * batch_size * seq_len + b * seq_len + current_token_idx] = current_pos; + max_pos_for_batch = current_pos; + current_pos++; // Only increment position for non-pad tokens + } + } + } + rope_deltas_.push_back(max_pos_for_batch + 1 - seq_len); + } + + // Move tensor to GPU and expand by num_beams + position_ids_->ort_tensor_ = model_.ExpandInputs(position_ids, state_.params_->search.num_beams); + position_ids_shape_[1] *= state_.params_->search.num_beams; + state_.inputs_[posid_input_index_] = position_ids_->GetOrtTensor(); + + // Expand rope_deltas_ + std::vector expanded_deltas; + for (int64_t delta : rope_deltas_) { + for (int b = 0; b < state_.params_->search.num_beams; ++b) { + expanded_deltas.push_back(delta); + } + } + rope_deltas_ = std::move(expanded_deltas); +} + +template +void Qwen2VLPositionInputs::CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape) { + auto attention_mask = OrtValue::CreateTensor(model_.allocator_cpu_, shape, type_); + auto* mask_data = attention_mask->GetTensorMutableData(); + auto input_ids_span = next_tokens.CpuSpan(); + int64_t batch_size = shape[0]; + int64_t seq_len = shape[1]; + + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t s = 0; s < seq_len; ++s) { + int64_t current_token_idx = b * seq_len + s; + mask_data[current_token_idx] = (input_ids_span[current_token_idx] == model_.config_->model.pad_token_id) + ? static_cast(0) + : static_cast(1); + } + } + + // Move tensor to GPU and expand by num_beams + attention_mask_->ort_tensor_ = model_.ExpandInputs(attention_mask, state_.params_->search.num_beams); + attention_mask_shape_[0] *= state_.params_->search.num_beams; + state_.inputs_[mask_input_index_] = attention_mask_->GetOrtTensor(); +} + +void Qwen2VLPositionInputs::Update3DPositionIDs(int base_pos) { + // This is the generation step (decode) + // base_pos is cache_position[0] + auto position_ids = OrtValue::CreateTensor(model_.allocator_cpu_, position_ids_shape_, type_); + int64_t batch_size = position_ids_shape_[1]; // This is already expanded (batch*beams) + int64_t seq_len = position_ids_shape_[2]; // This will be 1 for generation + + if (rope_deltas_.size() != batch_size) { + throw std::runtime_error("rope_deltas size mismatch with batch_size * num_beams."); + } + + DispatchOnType(type_, [&]() { + auto* data = position_ids->GetTensorMutableData(); + for (int64_t dim = 0; dim < 3; ++dim) { + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t s = 0; s < seq_len; ++s) { + // From Python: delta = (cache_position[0] + self.rope_deltas) + // cache_position[0] is `base_pos`. + T delta = static_cast(base_pos + rope_deltas_[b]); + // Python: position_ids = position_ids + delta + // `position_ids` for new token is just [0, 1, ...] + T pos = static_cast(s); + data[dim * batch_size * seq_len + b * seq_len + s] = delta + pos; + } + } + } + }); + + position_ids_->ort_tensor_ = model_.ExpandInputs(position_ids, 1); // No beam expansion needed, already expanded + state_.inputs_[posid_input_index_] = position_ids_->GetOrtTensor(); +} + +void Qwen2VLPositionInputs::UpdateAttentionMask() { + auto attention_mask = OrtValue::CreateTensor(model_.allocator_cpu_, attention_mask_shape_, type_); + + DispatchOnType(type_, [&]() { + auto* mask_data = attention_mask->GetTensorMutableData(); + std::fill_n(mask_data, attention_mask_shape_[0] * attention_mask_shape_[1], static_cast(1)); + }); + + attention_mask_->ort_tensor_ = model_.ExpandInputs(attention_mask, 1); + state_.inputs_[mask_input_index_] = attention_mask_->GetOrtTensor(); +} + +void Qwen2VLPositionInputs::Update(DeviceSpan next_tokens, int total_length, int new_length) { + if (has_posid_input_) { + position_ids_shape_[2] = new_length; + if (is_first_update_) { + DispatchOnType(type_, [&]() { + CreateAndInitialize3DPositionIDs(next_tokens, position_ids_shape_); + }); + } else { + Update3DPositionIDs(total_length - new_length); + } + } + + if (has_mask_input_) { + if (is_first_update_) { + attention_mask_shape_[1] = new_length; + DispatchOnType(type_, [&]() { + CreateAndInitializeAttentionMask(next_tokens, attention_mask_shape_); + }); + } else { + attention_mask_shape_[1] = total_length; + UpdateAttentionMask(); + } + } + + is_first_update_ = false; +} + +void Qwen2VLPositionInputs::RewindTo(size_t index) { + // For Qwen2-VL, we need to handle rewinding for beam search + // This is a simplified rewind, just updating the shape. + // A full rewind would require re-calculating rope_deltas if we rewound into the prompt. + // For now, we assume rewind only happens during generation. + if (has_posid_input_) { + position_ids_shape_[2] = static_cast(index); + } + if (has_mask_input_) { + attention_mask_shape_[1] = static_cast(index); + } +} + std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths, const std::string& attention_mask_name) { + // Check for Qwen2-VL model type which requires 3D position IDs + if (ModelType::IsQwen2VL(state.model_.config_->model.type)) { + return std::make_unique(state.model_, state, sequence_lengths); + } + if (state.model_.config_->model.decoder.sliding_window.has_value() && state.model_.config_->model.decoder.sliding_window->slide_inputs) { return std::make_unique(state); } else { @@ -485,4 +875,4 @@ std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths_unk); + Qwen2VLPositionInputs(const Qwen2VLPositionInputs&) = delete; + Qwen2VLPositionInputs& operator=(const Qwen2VLPositionInputs&) = delete; + + void Add() override; + void Update(DeviceSpan next_tokens, int total_length, int new_length) override; + void RewindTo(size_t index) override; + + void SetGridTensors(const std::shared_ptr& image_grid_thw, + const std::shared_ptr& video_grid_thw, + const std::shared_ptr& second_per_grid_ts); + + private: + void AddPositionIDs(); + void AddAttentionMask(); + + template + void CreateAndInitialize3DPositionIDs(DeviceSpan next_tokens, std::array shape); + void Update3DPositionIDs(int base_pos); + + template + void CreateAndInitializeAttentionMask(DeviceSpan next_tokens, std::array shape); + void UpdateAttentionMask(); + + const Model& model_; + State& state_; + + size_t mask_input_index_{~0U}; + size_t posid_input_index_{~0U}; + + ONNXTensorElementDataType type_; + + bool has_mask_input_{false}; + bool has_posid_input_{false}; + + std::array position_ids_shape_{}; // {3, batch_size, sequence_length} for 3D positions + std::unique_ptr position_ids_; + + std::array attention_mask_shape_{}; // {batch_size, sequence_length} + std::unique_ptr attention_mask_; + + bool is_first_update_{true}; + + // Cached data from processor + std::shared_ptr image_grid_thw_; + std::shared_ptr video_grid_thw_; + std::shared_ptr second_per_grid_ts_; + std::vector rope_deltas_; + + // Config values initialized from model.config_ in constructor + const int32_t image_token_id_; + const int32_t video_token_id_; + const int32_t vision_start_token_id_; + const float tokens_per_second_; + const int32_t spatial_merge_size_; +}; + std::unique_ptr CreatePositionInputs(State& state, DeviceSpan sequence_lengths, const std::string& attention_mask_name); -} // namespace Generators +} // namespace Generators \ No newline at end of file diff --git a/src/models/qwen_image_processor.cpp b/src/models/qwen_image_processor.cpp new file mode 100644 index 0000000000..32a4ca56e1 --- /dev/null +++ b/src/models/qwen_image_processor.cpp @@ -0,0 +1,306 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "../generators.h" +#include "model.h" + +#include + +namespace Generators { + +namespace { + +// constexpr int64_t kMergeSize = 2; // Qwen2-VL merge size for vision tokens +std::tuple, std::unique_ptr> +ProcessImagePrompt(const Generators::Tokenizer& tokenizer, const std::string& prompt, + OrtxTensor* pixel_values, OrtxTensor* image_grid_thw, + const int64_t* computed_grid_data, int64_t computed_grid_num_images, + Ort::Allocator& allocator, int64_t spatial_merge_size) { + constexpr char vision_start_token[] = "<|vision_start|>"; + constexpr char vision_end_token[] = "<|vision_end|>"; + constexpr char image_pad_token[] = "<|image_pad|>"; + + int64_t num_images = 0; + int64_t total_image_tokens = 0; + const int64_t* image_grid_thw_data = nullptr; + + if (pixel_values) { + const float* pixel_values_data{}; + const int64_t* pixel_values_shape{}; + size_t pixel_values_num_dims; + CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast(&pixel_values_data), + &pixel_values_shape, &pixel_values_num_dims)); + + // Get image_grid_thw data from either processor output or computed value + if (image_grid_thw) { + const int64_t* image_grid_thw_shape{}; + size_t image_grid_thw_num_dims; + CheckResult(OrtxGetTensorData(image_grid_thw, reinterpret_cast(&image_grid_thw_data), + &image_grid_thw_shape, &image_grid_thw_num_dims)); + num_images = image_grid_thw_shape[0]; + } else if (computed_grid_data) { + image_grid_thw_data = computed_grid_data; + num_images = computed_grid_num_images; + } + + // Calculate total image tokens based on grid dimensions + // For each image: (temporal * height * width) / (merge_size^2) + for (int64_t i = 0; i < num_images; ++i) { + int64_t t = image_grid_thw_data[i * 3 + 0]; + int64_t h = image_grid_thw_data[i * 3 + 1]; + int64_t w = image_grid_thw_data[i * 3 + 2]; + total_image_tokens += (t * h * w) / (spatial_merge_size * spatial_merge_size); + } + } + + // Generate input_ids with vision tokens + std::string text = prompt; + + // If prompt is empty, add vision markers for each image + if (text.empty()) { + for (int64_t i = 0; i < num_images; ++i) { + text += std::string(vision_start_token) + " " + std::string(vision_end_token); + if (i < num_images - 1) { + text += " "; + } + } + } + + // Count the number of vision_start tokens and make sure it matches the number of images + // Need to escape special regex characters in the token + const std::regex vision_start_regex{R"(<\|vision_start\|>)"}; + const auto vision_start_begin = std::sregex_iterator(text.begin(), text.end(), vision_start_regex); + const auto vision_start_end = std::sregex_iterator(); + const auto vision_start_tokens = std::distance(vision_start_begin, vision_start_end); + + if (num_images != vision_start_tokens) { + throw std::runtime_error("Prompt contained " + std::to_string(vision_start_tokens) + + " vision_start tokens but received " + std::to_string(num_images) + " images."); + } + + // For Qwen2-VL, we need to replace vision markers with image_pad tokens + // The number of image_pad tokens for each image depends on the image dimensions + if (num_images > 0 && image_grid_thw_data) { + std::string modified_text; + size_t last_pos = 0; + size_t image_idx = 0; + + std::smatch match; + std::string temp_text = text; + while (std::regex_search(temp_text, match, vision_start_regex)) { + // Add text before the vision_start token + modified_text += text.substr(last_pos, match.position() - (last_pos - (text.size() - temp_text.size()))); + + // Calculate number of image_pad tokens for this image + int64_t t = image_grid_thw_data[image_idx * 3 + 0]; + int64_t h = image_grid_thw_data[image_idx * 3 + 1]; + int64_t w = image_grid_thw_data[image_idx * 3 + 2]; + int64_t num_pads = (t * h * w) / (spatial_merge_size * spatial_merge_size); + + // Add vision_start, image_pad tokens, and vision_end + modified_text += vision_start_token; + for (int64_t i = 0; i < num_pads; ++i) { + modified_text += image_pad_token; + } + modified_text += vision_end_token; + + last_pos = match.position() + match.length() + (text.size() - temp_text.size()); + + // Find and skip vision_end token + size_t vision_end_pos = text.find(vision_end_token, last_pos); + if (vision_end_pos != std::string::npos) { + last_pos = vision_end_pos + strlen(vision_end_token); + } + + temp_text = match.suffix(); + image_idx++; + } + modified_text += text.substr(last_pos); + text = modified_text; + } + + const std::vector input_ids = tokenizer.Encode(text.c_str()); + + std::unique_ptr input_ids_value = OrtValue::CreateTensor( + allocator, std::vector{1, static_cast(input_ids.size())}); + std::copy(input_ids.begin(), input_ids.end(), input_ids_value->GetTensorMutableData()); + + std::unique_ptr num_img_tokens = OrtValue::CreateTensor( + allocator, std::vector{1}); + num_img_tokens->GetTensorMutableData()[0] = total_image_tokens; + + return {std::move(input_ids_value), std::move(num_img_tokens)}; +} + +} // namespace + +QwenImageProcessor::QwenImageProcessor(Config& config, const SessionInfo& session_info) + : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)}, + spatial_merge_size_{config.model.vision.spatial_merge_size} { + const auto processor_config = (config.config_path / fs::path(config.model.vision.config_filename)).string(); + CheckResult(OrtxCreateProcessor(processor_.ToBeAssigned(), processor_config.c_str())); + + config.AddMapping(std::string(Config::Defaults::InputIdsName), config.model.embedding.inputs.input_ids); + config.AddMapping(std::string(Config::Defaults::PixelValuesName), config.model.vision.inputs.pixel_values); +} + +std::unique_ptr QwenImageProcessor::Process(const Tokenizer& tokenizer, const Payload& payload) const { + std::string prompt = std::string(payload.prompt); + const Images* images = payload.images; + Ort::Allocator& allocator{Ort::Allocator::GetWithDefaultOptions()}; + auto named_tensors = std::make_unique(); + + if (!images) { + [[maybe_unused]] auto [input_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, nullptr, nullptr, nullptr, 0, allocator, spatial_merge_size_); + named_tensors->emplace(Config::Defaults::InputIdsName, std::make_shared(std::move(input_ids))); + return named_tensors; + } + + ort_extensions::OrtxObjectPtr result; + CheckResult(OrtxImagePreProcess(processor_.get(), images->images_.get(), result.ToBeAssigned())); + + OrtxTensor* pixel_values = nullptr; + CheckResult(OrtxTensorResultGetAt(result.get(), 0, &pixel_values)); + + OrtxTensor* image_grid_thw = nullptr; + // Try to get image_grid_thw from processor (second output) + auto status = OrtxTensorResultGetAt(result.get(), 1, &image_grid_thw); + + // Get pixel_values data and shape + const float* pixel_values_data{}; + const int64_t* pixel_values_shape{}; + size_t pixel_values_num_dims; + CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast(&pixel_values_data), + &pixel_values_shape, &pixel_values_num_dims)); + + // If processor doesn't provide image_grid_thw or patched pixel_values, compute them + std::unique_ptr computed_image_grid_thw; + std::unique_ptr patched_pixel_values; + const int64_t* computed_grid_data = nullptr; + int64_t computed_grid_num_images = 0; + + // Check if pixel_values needs patching (shape should be [1, height, width, channels] in HWC format) + if (pixel_values_num_dims == 4 && pixel_values_shape[0] == 1) { + constexpr int64_t kPatchSize = 14; + constexpr int64_t kTemporalPatchSize = 2; + constexpr int64_t kChannels = 3; + + int64_t height = pixel_values_shape[1]; // HWC: [batch, height, width, channels] + int64_t width = pixel_values_shape[2]; + int64_t channels = pixel_values_shape[3]; + + int64_t height_patches = height / kPatchSize; + int64_t width_patches = width / kPatchSize; + int64_t total_patches = height_patches * width_patches; + int64_t patch_dim = channels * kTemporalPatchSize * kPatchSize * kPatchSize; + + // Create patched pixel_values: [total_patches, patch_dim] + patched_pixel_values = OrtValue::CreateTensor( + allocator, std::vector{total_patches, patch_dim}); + auto* patched_data = patched_pixel_values->GetTensorMutableData(); + + // Extract patches from single image in HWC format + // Each spatial patch is replicated kTemporalPatchSize times + int64_t patch_idx = 0; + for (int64_t ph = 0; ph < height_patches; ++ph) { + for (int64_t pw = 0; pw < width_patches; ++pw) { + int64_t h_start = ph * kPatchSize; + int64_t w_start = pw * kPatchSize; + + int64_t write_idx = patch_idx * patch_dim; + + // Repeat the same spatial patch kTemporalPatchSize times + // Output: [temporal, channels, patch_h, patch_w] + for (int64_t t = 0; t < kTemporalPatchSize; ++t) { + for (int64_t c = 0; c < channels; ++c) { + for (int64_t h = 0; h < kPatchSize; ++h) { + for (int64_t w = 0; w < kPatchSize; ++w) { + // HWC format: pixel_values[height][width][channels] + int64_t src_idx = (h_start + h) * width * channels + (w_start + w) * channels + c; + patched_data[write_idx++] = pixel_values_data[src_idx]; + } + } + } + } + patch_idx++; + } + } + + // Create image_grid_thw: [1, 3] for single image + if (status != kOrtxOK || !image_grid_thw) { + computed_image_grid_thw = OrtValue::CreateTensor( + allocator, std::vector{1, 3}); + auto* grid_data = computed_image_grid_thw->GetTensorMutableData(); + + // For a single image: T=1 (one frame), H=height_patches, W=width_patches + // The kTemporalPatchSize is embedded in the patch dimension + grid_data[0] = 1; // Single temporal frame for images + grid_data[1] = height_patches; + grid_data[2] = width_patches; + + computed_grid_data = grid_data; + computed_grid_num_images = 1; + } + } + + auto [input_ids, num_img_tokens] = ProcessImagePrompt(tokenizer, prompt, pixel_values, + image_grid_thw, computed_grid_data, computed_grid_num_images, allocator, spatial_merge_size_); + named_tensors->emplace(std::string(Config::Defaults::InputIdsName), std::make_shared(std::move(input_ids))); + + // Use patched pixel_values if we computed it, otherwise use processor output + if (patched_pixel_values) { + // Convert to the correct type if needed + if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) { + // Convert float to bfloat16 + auto shape_vec = patched_pixel_values->GetTensorTypeAndShapeInfo()->GetShape(); + auto bf16_tensor = OrtValue::CreateTensor(allocator, shape_vec); + const float* src = patched_pixel_values->GetTensorData(); + auto* dst = static_cast(bf16_tensor->GetTensorMutableData()); + size_t count = patched_pixel_values->GetTensorTypeAndShapeInfo()->GetElementCount(); + for (size_t i = 0; i < count; ++i) { + dst[i] = Float32ToBFloat16(src[i]); + } + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(std::move(bf16_tensor))); + } else if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // Convert float to float16 + auto shape_vec = patched_pixel_values->GetTensorTypeAndShapeInfo()->GetShape(); + auto fp16_tensor = OrtValue::CreateTensor(allocator, shape_vec); + const float* src = patched_pixel_values->GetTensorData(); + auto* dst = static_cast(fp16_tensor->GetTensorMutableData()); + size_t count = patched_pixel_values->GetTensorTypeAndShapeInfo()->GetElementCount(); + for (size_t i = 0; i < count; ++i) { + dst[i] = FastFloat32ToFloat16(src[i]); + } + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(std::move(fp16_tensor))); + } else { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(std::move(patched_pixel_values))); + } + } else if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(ProcessTensor(pixel_values, allocator))); + } else if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(ProcessTensor(pixel_values, allocator))); + } else { + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), + std::make_shared(ProcessTensor(pixel_values, allocator))); + } + + // Add image_grid_thw tensor (either from processor or computed) + if (image_grid_thw) { + named_tensors->emplace("image_grid_thw", + std::make_shared(ProcessTensor(image_grid_thw, allocator))); + } else if (computed_image_grid_thw) { + named_tensors->emplace("image_grid_thw", + std::make_shared(std::move(computed_image_grid_thw))); + } + + named_tensors->emplace(std::string(Config::Defaults::NumImageTokens), std::make_shared(std::move(num_img_tokens))); + + return named_tensors; +} + +} // namespace Generators diff --git a/src/models/qwen_image_processor.h b/src/models/qwen_image_processor.h new file mode 100644 index 0000000000..ce1ba26f0b --- /dev/null +++ b/src/models/qwen_image_processor.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "processor.h" + +namespace Generators { + +struct QwenImageProcessor : Processor { + QwenImageProcessor(Config& config, const SessionInfo& session_info); + + virtual std::unique_ptr Process(const Tokenizer& tokenizer, const Payload& payload) const override; + + private: + ort_extensions::OrtxObjectPtr processor_; + + ONNXTensorElementDataType pixel_values_type_; + int64_t spatial_merge_size_; +}; + +} // namespace Generators \ No newline at end of file