diff --git a/paddlemix/examples/mPLUG_Owl3/README.md b/paddlemix/examples/mPLUG_Owl3/README.md
new file mode 100644
index 000000000..baababe99
--- /dev/null
+++ b/paddlemix/examples/mPLUG_Owl3/README.md
@@ -0,0 +1,42 @@
+# mPLUG-Owl3
+
+## 1. Model Introduction
+
+**Model weights supported in this repository:**
+
+| Model |
+|--------------------|
+| mPLUG/mPLUG-Owl3-7B-241101 |
+
+Note: the name matches the Hugging Face weights, but the tensors are stored in Paddle format. Calling `xxx.from_pretrained("mPLUG/mPLUG-Owl3-7B-241101")` automatically downloads the weight folder into the cache directory.
+
+
+## 2. Environment Setup
+
+1) [Install the PaddleMIX dependencies](https://github.com/PaddlePaddle/PaddleMIX/tree/develop?tab=readme-ov-file#%E5%AE%89%E8%A3%85)
+
+2) pip install pillow tqdm paddlenlp==3.0.0b2
+
+Note: Python 3.10 or later is recommended.
+
+## 3. Quick Start
+
+### Inference
+```bash
+# Image understanding
+CUDA_VISIBLE_DEVICES=0 python paddlemix/examples/mPLUG_Owl3/run_inference.py
+```
+
+
+### References
+```BibTeX
+@misc{ye2024mplugowl3longimagesequenceunderstanding,
+      title={mPLUG-Owl3: Towards Long Image-Sequence Understanding in Multi-Modal Large Language Models},
+      author={Jiabo Ye and Haiyang Xu and Haowei Liu and Anwen Hu and Ming Yan and Qi Qian and Ji Zhang and Fei Huang and Jingren Zhou},
+      year={2024},
+      eprint={2408.04840},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2408.04840},
+}
+```
diff --git a/paddlemix/examples/mPLUG_Owl3/requirement.txt b/paddlemix/examples/mPLUG_Owl3/requirement.txt
new file mode 100644
index 000000000..c1cc9aebb
--- /dev/null
+++ b/paddlemix/examples/mPLUG_Owl3/requirement.txt
@@ -0,0 +1,3 @@
+pillow
+tqdm
+paddlenlp==3.0.0b2
\ No newline at end of file
diff --git a/paddlemix/examples/mPLUG_Owl3/run_inference.py b/paddlemix/examples/mPLUG_Owl3/run_inference.py
new file mode 100644
index 000000000..8db4d537d
--- /dev/null
+++ b/paddlemix/examples/mPLUG_Owl3/run_inference.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
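+
+# Minimal single-image inference example for mPLUG-Owl3 on PaddleMIX: load the Paddle
+# weights, build the processor from the Qwen2 tokenizer, and generate a description
+# for one local image.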
+ +import paddle +from paddlenlp.transformers import Qwen2Tokenizer +from PIL import Image + +from paddlemix.models.mPLUGOwl3.configuration_mplugowl3 import mPLUGOwl3Config +from paddlemix.models.mPLUGOwl3.modeling_mplugowl3 import mPLUGOwl3Model + +model_path = "mPLUG/mPLUG-Owl3-7B-241101" + +config = mPLUGOwl3Config.from_pretrained(model_path) +model = mPLUGOwl3Model.from_pretrained(model_path, dtype=paddle.bfloat16).eval() +tokenizer = Qwen2Tokenizer.from_pretrained(model_path) +processor = model.init_processor(tokenizer) + +# image = Image.new('RGB', (500, 500), color='red') +image = Image.open("paddlemix/demo_images/examples_image1.jpg").convert("RGB") + +messages = [{"role": "user", "content": """<|image|>Describe this image."""}, {"role": "assistant", "content": ""}] + +inputs = processor(messages, images=[image], videos=None) +inputs["pixel_values"] = inputs["pixel_values"].cast(paddle.bfloat16) + +inputs.update( + { + "tokenizer": tokenizer, + "max_new_tokens": 512, # + "decode_text": True, + } +) + +res = model.generate(**inputs) +print("output:\n", res) diff --git a/paddlemix/models/mPLUGOwl3/__init__.py b/paddlemix/models/mPLUGOwl3/__init__.py new file mode 100644 index 000000000..a9bd46569 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_hyper_qwen2 import * +from .configuration_mplugowl3 import * +from .modeling_hyper_qwen2 import * +from .modeling_mplugowl3 import * +from .modeling_navit_siglip import * diff --git a/paddlemix/models/mPLUGOwl3/configuration_hyper_qwen2.py b/paddlemix/models/mPLUGOwl3/configuration_hyper_qwen2.py new file mode 100644 index 000000000..d0057b667 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/configuration_hyper_qwen2.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddlenlp.transformers import PretrainedConfig + + +class HyperQwen2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a + Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of + Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). 
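+    Compared with the stock Qwen2 configuration, this variant additionally exposes `hyper_layers`,
+    `vision_batch_size` and an optional `rope_scaling` dict used by mPLUG-Owl3's hyper attention.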
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Qwen2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 22016): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + use_sliding_window (`bool`, *optional*, defaults to `False`): + Whether to use sliding window attention. + sliding_window (`int`, *optional*, defaults to 4096): + Sliding window attention (SWA) window size. If not specified, will default to `4096`. + max_window_layers (`int`, *optional*, defaults to 28): + The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
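+        hyper_layers (`List[int]`, *optional*, defaults to `[1, 9, 17, 25]`):
+            1-based indices of the decoder layers that attend to image embeddings via hyper attention.
+        vision_batch_size (`int`, *optional*, defaults to 16):
+            Maximum number of images encoded by the vision tower in a single forward chunk.
+        rope_scaling (`dict`, *optional*):
+            RoPE scaling configuration; a `"type"` key, if present, is mirrored to `"rope_type"` for
+            backward compatibility.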
+ + ```python + >>> from transformers import Qwen2Model, Qwen2Config + + >>> # Initializing a Qwen2 style configuration + >>> configuration = Qwen2Config() + + >>> # Initializing a model from the Qwen2-7B style configuration + >>> model = Qwen2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qwen2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + hyper_layers=[1,9,17,25], + vision_batch_size=16, + rope_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window if use_sliding_window else None + self.max_window_layers = max_window_layers + self.rope_scaling = rope_scaling + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.hyper_layers = hyper_layers + self.vision_batch_size = vision_batch_size + self.seq_length = 1 #self.max_length ### + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/paddlemix/models/mPLUGOwl3/configuration_mplugowl3.py b/paddlemix/models/mPLUGOwl3/configuration_mplugowl3.py new file mode 100644 index 000000000..f31fa6049 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/configuration_mplugowl3.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
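+
+# mPLUGOwl3Config extends HyperQwen2Config with a SigLIP vision configuration
+# (378x378 input images with patch size 14 by default).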
+ +from paddlemix.utils.log import logger + +from .configuration_hyper_qwen2 import HyperQwen2Config +from .modeling_navit_siglip import SigLipVisionConfig + + +class mPLUGOwl3Config(HyperQwen2Config): + model_type = "mplugowl3" + keys_to_ignore_at_inference = ["past_key_values"] + + default_vision_config = { + "hidden_size": 1152, + "image_size": 378, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + def __init__( + self, + use_cache=True, + vision_config=None, + **kwargs, + ): + self.use_cache = use_cache + + # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes + if vision_config is None: + self.vision_config = SigLipVisionConfig(**self.default_vision_config) + logger.info("vision_config is None, using default vision config") + elif isinstance(vision_config, dict): + self.vision_config = SigLipVisionConfig(**vision_config) + elif isinstance(vision_config, SigLipVisionConfig): + self.vision_config = vision_config + self.image_size = 378 + self.patch_size = self.vision_config.patch_size + + super().__init__(**kwargs) diff --git a/paddlemix/models/mPLUGOwl3/modeling_hyper_qwen2.py b/paddlemix/models/mPLUGOwl3/modeling_hyper_qwen2.py new file mode 100644 index 000000000..104a9bf23 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/modeling_hyper_qwen2.py @@ -0,0 +1,983 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.nn as nn +import paddlenlp +from einops import rearrange, repeat +from paddle.nn import MultiHeadAttention +from paddlenlp.transformers.qwen2.modeling import Qwen2Attention + +from paddlemix.utils.log import logger + +from ...activations import ACT2FN +from .configuration_hyper_qwen2 import HyperQwen2Config + + +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
+ """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class Qwen2RMSNorm(paddle.nn.Layer): + def __init__(self, hidden_size, eps=1e-06): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = paddle.base.framework.EagerParamBase.from_tensor(tensor=paddle.ones(shape=hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to("float32") + variance = hidden_states.pow(y=2).mean(axis=-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(x=variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Qwen2RotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / self.base ** ( + paddle.arange(start=0, end=self.dim, step=2, dtype="int64").astype(dtype="float32") / self.dim + ) + self.register_buffer(name="inv_freq", tensor=inv_freq, persistable=False) + self._set_cos_sin_cache(seq_len=max_position_embeddings, dtype=paddle.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, dtype): + self.max_seq_len_cached = seq_len + t = paddle.arange(dtype="int64", end=self.max_seq_len_cached).astype(dtype=self.inv_freq.dtype) + freqs = paddle.outer(x=t, y=self.inv_freq) + emb = paddle.concat(x=(freqs, freqs), axis=-1) + self.register_buffer(name="cos_cached", tensor=emb.cos().to(dtype), persistable=False) + self.register_buffer(name="sin_cached", tensor=emb.sin().to(dtype), persistable=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class RotaryEmbedding(paddle.nn.Layer): + def __init__(self, dim, base=10000, use_fp32=False, use_outer_in_rope=False): + super().__init__() + self.dim = dim + self.base = base + self.use_fp32 = use_fp32 + if use_fp32: + self.inv_freq = 1.0 / base ** (paddle.arange(start=0, end=dim, step=2).astype(dtype="float32") / dim) + else: + inv_freq = 1.0 / base ** (paddle.arange(start=0, end=dim, step=2).astype(dtype="float32") / dim) + self.register_buffer(name="inv_freq", tensor=inv_freq) + + self._rotary_pos_emb_cache = None + self._seq_len_cached = 0 + self.use_outer_in_rope = use_outer_in_rope + self._ntk_alpha_cached = 1.0 + + def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): + seqlen = max_seq_len + offset + if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / base ** ( + paddle.arange(start=0, end=self.dim, step=2).astype(dtype="float32") / self.dim + ) + self._seq_len_cached = seqlen + self._ntk_alpha_cached = ntk_alpha + seq = paddle.arange(end=seqlen) + if 1: # self.use_outer_in_rope: + freqs = paddle.outer(x=seq.astype(dtype=self.inv_freq.dtype), y=self.inv_freq) + # else: + # freqs = einsum("i , j -> i j", seq.astype(dtype=self.inv_freq.dtype), self.inv_freq) + emb = 
paddle.concat(x=(freqs, freqs), axis=-1) + # emb [seq_length, .., dim] + self._rotary_pos_emb_cache = rearrange(emb, "n d -> n 1 1 d") + + def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) + return self._rotary_pos_emb_cache[offset : offset + max_seq_len] + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : tuple(x.shape)[-1] // 2] + x2 = x[..., tuple(x.shape)[-1] // 2 :] + return paddle.concat(x=(-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2MLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = tuple(hidden_states.shape) + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(shape=[batch, num_key_value_heads, n_rep, slen, head_dim]) + return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim]) + + +def _rotate_half(x): + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(axis=-2) + return paddle.concat(x=(-x2, x1), axis=-1) + + +def apply_rotary_pos_emb_core(t, freqs, use_fp32=False, debug=False): + """ + input tensor t is of shape [seq_length, ..., dim] + rotary positional embeding tensor freqs is of shape [seq_length, ..., dim] + check https://kexue.fm/archives/8265 for detailed formulas + """ + # if use_flash_rotary and use_fp32: + # t_ = rearrange(t, "s b ... -> b s ...") + # if use_fp32: + # t_ = t_.astype(dtype="float32") + # freqs = freqs.squeeze(axis=1).squeeze(axis=1) + # cos = freqs[:, :freqs.shape[-1] // 2].cos() + # sin = freqs[:, :freqs.shape[-1] // 2].sin() + # output = apply_rotary_emb_func(t_, cos, sin).astype(dtype=t.dtype) # TODO + # return rearrange(output, 'b s ... -> s b ...') + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] + + if use_fp32: + t_ = t_.astype(dtype="float32") + t_pass_ = t_pass_.astype(dtype="float32") + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin()) + return paddle.concat(x=(t_, t_pass_), axis=-1).astype(dtype=t.dtype) + + +class HyperQwen2Attention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: HyperQwen2Config, layer_idx: Optional[int] = None, is_hyper_enabled=False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias_attr=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias_attr=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias_attr=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias_attr=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + self.rotary_emb_core = RotaryEmbedding( + self.head_dim, base=self.rope_theta, use_fp32=True, use_outer_in_rope=True + ) + # Hyper Attention Modules + self.is_hyper_enabled = is_hyper_enabled + if self.is_hyper_enabled: + self.v_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias_attr=True) + + self.visual_cache = {} + + self.use_flexattention = True + + def apply_mi_rope(self, key_layer, image_pos, length_each_img): + # input shape should be [s b h d] + key_layer = rearrange(key_layer, "b h s d -> s b h d") + rotary_pos_emb_max_seq_len = self.config.max_position_embeddings + ntk_alpha = 1 + rotary_pos_emb = self.rotary_emb_core(rotary_pos_emb_max_seq_len, ntk_alpha=ntk_alpha) + assert rotary_pos_emb is not None + + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = (rotary_pos_emb,) * 2 + + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + k_pos_emb = repeat(k_pos_emb[image_pos], "N_img b h d -> (N_img L) b h d", L=length_each_img) # N_img, dim + + key_layer = apply_rotary_pos_emb_core(key_layer, k_pos_emb, use_fp32=True) # TODO difference + key_layer = rearrange(key_layer, "s b h d -> b h s d") + return key_layer + + +class HyperQwen2SdpaAttention(HyperQwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
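+    In this Paddle port, `paddle.nn.functional.scaled_dot_product_attention` is used instead; note that it
+    expects (batch, seq_len, num_heads, head_dim) inputs, so q/k/v are transposed before the call.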
+ """ + + def hyperattention( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + image_embeds=None, + media_offset=None, + past_key_value: Optional[MultiHeadAttention.Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + bsz, q_len, _ = hidden_states.shape # (1, 74, 28, 128) bsz, q_len, self.num_heads, self.head_dim + + try: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + except: + hidden_states = hidden_states.astype("bfloat16") + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + key_states = key_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose([0, 2, 1, 3]) + value_states = value_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [1, 28, 1, 128] [1, 4, 1, 128] + + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=2) + value_states = paddle.concat([past_key_value[1], value_states], axis=2) + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # add visual to kv + length_each_img = image_embeds.shape[1] + try: + image_embeds = self.v_kv_proj(image_embeds) + except: + image_embeds = self.v_kv_proj(image_embeds.astype("bfloat16")) + image_start = 0 + context_layer = [] + for bi, media_starts in enumerate(media_offset): + num_images = media_starts.shape[0] + if num_images > 0: + if q_len == 1: + full_mask = paddle.ones((1, 1, 1, num_images * length_each_img + kv_seq_len)).astype(paddle.bool) + else: + causal_mask = paddle.tril(paddle.ones([q_len, kv_seq_len])).astype(paddle.bool) + # 扩展维度以匹配 (bsz, 1, q_len, kv_seq_len) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + matrix = paddle.arange(q_len).reshape([-1, 1]) + t2vmask = ~(matrix < media_starts.reshape([1, -1])) + t2vmask = repeat(t2vmask, "seq_t seq_v -> 1 1 seq_t (seq_v v_token)", v_token=length_each_img) + full_mask = paddle.concat( + [t2vmask, causal_mask], axis=3 + ) # unsqueeze batch dim (batch, 1, seq_q, seq_k) + + curr_query_layer = query_states[bi : bi + 1] + # order is sbhd + curr_visual_key_layer, curr_visual_value_layer = rearrange( + image_embeds[image_start : image_start + num_images], + "BL Lv (H KV D) -> KV 1 H (BL Lv) D", + KV=2, + H=self.num_key_value_heads, + ) # b h s d + image_start += num_images + + curr_visual_key_layer = self.apply_mi_rope( + curr_visual_key_layer, media_starts, length_each_img=length_each_img + ) + + curr_visual_key_layer = repeat_kv(curr_visual_key_layer, self.num_key_value_groups) + curr_visual_value_layer = repeat_kv(curr_visual_value_layer, self.num_key_value_groups) + + curr_key_layer = paddle.concat([curr_visual_key_layer, key_states[bi : bi + 1]], axis=2) + 
curr_value_layer = paddle.concat([curr_visual_value_layer, value_states[bi : bi + 1]], axis=2) + is_causal = False + else: + # 执行无图attention + curr_query_layer = query_states[bi : bi + 1] + curr_key_layer = key_states[bi : bi + 1] + curr_value_layer = value_states[bi : bi + 1] + is_causal = True if q_len > 1 else False + if is_causal: + full_mask = None + else: + causal_mask = paddle.tril(paddle.ones([q_len, kv_seq_len])).astype(paddle.bool) + # 扩展维度以匹配 (bsz, 1, q_len, kv_seq_len) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + full_mask = causal_mask + + # Note: 注意paddle的scaled_dot_product_attention 中q k v维度与torch不同 + attn_output = paddle.nn.functional.scaled_dot_product_attention( + curr_query_layer.transpose( + [0, 2, 1, 3] + ), # (batch, ..., sequence, dim) # [1, 74, 28, 128], torch [1, 28, 74, 128] sum 18304. + curr_key_layer.transpose( + [0, 2, 1, 3] + ), # [1, 5177, 28, 128], torch [1, 28, 5177, 128] sum 1044480 mean 0.05615234 torch sum 1036288. mean 0.0559 + curr_value_layer.transpose([0, 2, 1, 3]), # [1, 5177, 28, 128] , torch [1, 28, 5177, 128] sum -158720 + attn_mask=full_mask.cast( + curr_query_layer.dtype + ), # (N, ..., L, S) A boolean mask where a value of True indicates that the element *should* take part in attention. + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=is_causal, + ) # -> (N, ..., L, Ev) + # torch attn_output.shape [1, 28, 72, 128] + attn_output = attn_output.transpose([0, 2, 1, 3]) + assert attn_output.shape[0] == 1 + context_layer.append(attn_output) + attn_output = context_layer = paddle.concat(context_layer, axis=0) + + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape([bsz, q_len, self.hidden_size]) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + image_embeds=None, + media_offset=None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + # TODO: + # if output_attentions: + # # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + # logger.warning_once( + # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ # ) + # return super().forward( + # hidden_states=hidden_states, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_value, + # output_attentions=output_attentions, + # use_cache=use_cache, + # ) + + if self.is_hyper_enabled and image_embeds is not None: + return self.hyperattention( + hidden_states, + attention_mask, + position_ids, + image_embeds, + media_offset, + past_key_value, + output_attentions, + use_cache, + ) + + bsz, q_len, _ = hidden_states.shape + + try: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + except: + hidden_states = hidden_states.astype("bfloat16") + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + key_states = key_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose([0, 2, 1, 3]) + value_states = value_states.reshape([bsz, q_len, self.num_key_value_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=2) + value_states = paddle.concat([past_key_value[1], value_states], axis=2) + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if tuple(attention_mask.shape) != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {bsz, 1, q_len, kv_seq_len}, but is {tuple(attention_mask.shape)}" + ) + + # Note: 注意paddle的scaled_dot_product_attention 中q k v维度与torch不同 + attn_output = paddle.nn.functional.scaled_dot_product_attention( + query_states.transpose([0, 2, 1, 3]), # [1, 28, 74, 128] sum 21632. + key_states.transpose([0, 2, 1, 3]), # [1, 28, 74, 128] sum 335872. + value_states.transpose([0, 2, 1, 3]), # [1, 28, 74, 128] sum 1680. + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + # [1, 74, 28, 128] sum 1408. 
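+        # Paddle SDPA returns (batch, seq_len, num_heads, head_dim), so the head axes can be
+        # merged into hidden_size directly without a transpose.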
+ attn_output = attn_output.reshape([bsz, q_len, self.hidden_size]) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Original Attention of Qwen2 +# PaddleNLP only has Qwen2Attention +QWEN2_ATTENTION_CLASSES = { + "eager": Qwen2Attention, + "flash_attention_2": Qwen2Attention, # Qwen2FlashAttention2, + "sdpa": Qwen2Attention, # Qwen2SdpaAttention, +} + + +class HyperQwen2DecoderLayer(nn.Layer): + def __init__(self, config: HyperQwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.is_hyper_enabled = (layer_idx + 1) in config.hyper_layers + # TODO: 若使用Qwen2Attention则回答结果不对,若都使用HyperQwen2SdpaAttention回答结果也对,但需check一下 + if 1: # self.is_hyper_enabled: + self.self_attn = HyperQwen2SdpaAttention(config, layer_idx, is_hyper_enabled=self.is_hyper_enabled) + else: + # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = QWEN2_ATTENTION_CLASSES["flash_attention_2"](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + image_embeds=None, + media_offset=None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Shared LayerNorm + if image_embeds is not None and self.is_hyper_enabled: + image_embeds = self.input_layernorm(image_embeds) + media_kwargs = {"image_embeds": image_embeds, "media_offset": media_offset} + else: + image_embeds = media_offset = None + media_kwargs = {} + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( # -704. 2080. (48128., 240.) + hidden_states=hidden_states.cast(paddle.bfloat16), # [1, 74, 3584] sum -704. 
+ attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=True, # TODO, paddlenlp默认是False,但是不返回self_attn_weights。output_attentions全局是false,这里改成True是无影响的 + use_cache=use_cache, + **media_kwargs, # {} + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + try: + hidden_states = self.mlp(hidden_states.cast(paddle.bfloat16)) + except: + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Qwen2PreTrainedModel(paddlenlp.transformers.model_utils.PretrainedModel): + config_class = HyperQwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["HyperQwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + # _supports_flash_attn_2 = True + # _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, layer): + std = self.config.initializer_range + if isinstance(layer, (paddle.nn.Linear, paddle.nn.Conv3D)): + paddle.nn.initializer.Normal(mean=0.0, std=std)(layer.weight) + if layer.bias is not None: + paddle.nn.initializer.Constant(0.0)(layer.bias) + elif isinstance(layer, paddle.nn.Embedding): + paddle.nn.initializer.Normal(mean=0.0, std=std)(layer.weight) + if layer._padding_idx is not None: + with paddle.no_grad(): + layer.weight[layer._padding_idx] = 0.0 + + +class HyperQwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: HyperQwen2Config + """ + + def __init__(self, config: HyperQwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.LayerList( + [HyperQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = "flash_attention_2" # config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + image_embeds=None, + media_offset=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, paddlenlp.transformers.model_outputs.BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + past_key_values_length = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] # + past_key_values_length += cache_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0).reshape([-1, seq_length]) + else: + position_ids = position_ids.reshape([-1, seq_length]).astype(dtype="int64") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = None # + + hidden_states = inputs_embeds + + # beam search + if batch_size != len(media_offset): + # The model is performing beamsearch, repeat the visual content + beam_factor = batch_size // len(media_offset) + assert batch_size % len(media_offset) == 0 + media_offset = media_offset * beam_factor + image_embeds = repeat(image_embeds, "B L D -> (factor B) L D", factor=beam_factor) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () # not None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + image_embeds=image_embeds, + media_offset=media_offset, + past_key_value=past_key_value, # not past_key_values + output_attentions=output_attentions, + use_cache=use_cache, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + + hidden_states = layer_outputs[0] + + next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) if use_cache else None + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return paddlenlp.transformers.model_outputs.BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class HyperQwen2ForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HyperQwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False) + + # Initialize weights and 
apply final processing + # self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + image_embeds=None, + media_offset=None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, paddlenlp.transformers.model_outputs.CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, # [1, 74] # [1, 1] + attention_mask=attention_mask, # [1, 74] # [1, 75] + position_ids=position_ids, # [1, 74] # [1, 1] + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, # none + image_embeds=image_embeds, # [7, 729, 3584] sum 134144. 
+ media_offset=media_offset, # [[18, 24, 30, 36, 42, 48, 54]] + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, # + ) + + hidden_states = outputs[0] # sum 6656 mean 0.02502441 + try: + logits = self.lm_head(hidden_states) + except: + logits = self.lm_head(hidden_states.cast(paddle.bfloat16)) + logits = logits.cast(paddle.float32) # sum -5314405 mean -0.47356287 + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.reshape([-1, self.config.vocab_size]) + shift_labels = shift_labels.reshape([-1]) + # Enable model parallelism + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return paddlenlp.transformers.model_outputs.CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # 以下这段参考PaddleNLP的 Qwen2ForCausalLM 的写法,与torch的mPLUG-owl3不同 + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "image_embeds": kwargs.get("image_embeds"), + "media_offset": kwargs.get("media_offset"), + } + ) + return model_inputs diff --git a/paddlemix/models/mPLUGOwl3/modeling_mplugowl3.py b/paddlemix/models/mPLUGOwl3/modeling_mplugowl3.py new file mode 100644 index 000000000..a522e5165 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/modeling_mplugowl3.py @@ -0,0 +1,264 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
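+
+# mPLUG-Owl3 wrapper: a SigLIP vision transformer encodes images, a Linear-GELU-Linear
+# projector (vision2text_model) maps them into the language hidden size, and
+# HyperQwen2ForCausalLM consumes the projected image embeddings via hyper attention.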
+ +from threading import Thread + +import paddle +import paddle.nn as nn +from paddlenlp.generation import TextIteratorStreamer +from paddlenlp.transformers import Qwen2PretrainedModel + +from ...processors.mplugowl3_processing import ( + mPLUGOwl3ImageProcessor, + mPLUGOwl3Processor, +) +from .configuration_mplugowl3 import mPLUGOwl3Config +from .modeling_hyper_qwen2 import HyperQwen2ForCausalLM +from .modeling_navit_siglip import SigLipVisionTransformer + + +class mPLUGOwl3PreTrainedModel(Qwen2PretrainedModel): + config_class = mPLUGOwl3Config + _no_split_modules = ["HyperQwen2DecoderLayer", "SiglipVisionTransformer"] + + +class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.language_model = HyperQwen2ForCausalLM(config) + self.vision_model = self.init_vision_module() + self.vision_dim = self.vision_model.embed_dim + self.embed_dim = self.config.hidden_size + self.vision2text_model = nn.Sequential( + nn.Linear(self.vision_dim, self.embed_dim), nn.GELU(), nn.Linear(self.embed_dim, self.embed_dim) + ) + self.processor = None + self.terminators = ["<|im_end|>", "<|endoftext|>"] + self.vision_batch_size = config.vision_batch_size + + def init_vision_module(self): + self.config.vision_config._attn_implementation = "flash_attention_2" + model = SigLipVisionTransformer(self.config.vision_config) + setattr(model, "embed_dim", model.embeddings.embed_dim) + setattr(model, "patch_size", model.embeddings.patch_size) + return model + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def get_output_embeddings(self): + return self.language_model.lm_head + + def set_output_embeddings(self, new_embeddings): + self.language_model.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def _small_batched_forward(self, pixel_values): + vision_batch_size = self.vision_batch_size + image_forward_out = [] + B = len(pixel_values) + for i in range(0, B, vision_batch_size): + start_idx = i + end_idx = min(B, i + vision_batch_size) + tmp_hs = self.vision_model(pixel_values[start_idx:end_idx], output_hidden_states=True).hidden_states[-2] + image_forward_out.append(tmp_hs) + + vision_embedding = paddle.concat(image_forward_out, axis=0) + assert vision_embedding.shape[0] == B + return vision_embedding + + def forward_image(self, pixel_values): + if pixel_values is None: + return None + dtype = self.language_model.model.embed_tokens.weight.dtype + image_embeds = self._small_batched_forward(pixel_values.to(dtype)) + + if self.vision2text_model is not None: + image_embeds = self.vision2text_model(image_embeds) + else: + pass + + return image_embeds + + def forward(self, pixel_values=None, **kwargs): + image_embeds = self.forward_image(pixel_values) + + return self.language_model(image_embeds=image_embeds, **kwargs) + + def _decode(self, input_ids, image_embeds, media_offset, tokenizer, attention_mask, decode_text=False, **kwargs): + terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] + + # Note: must add position_ids, paddlenlp bug + batch_size, seq_length = input_ids.shape + position_ids = paddle.arange(seq_length).expand((batch_size, seq_length)) + + output = self.language_model.generate( + input_ids=input_ids, + image_embeds=image_embeds, + media_offset=media_offset, + pad_token_id=0, + eos_token_id=terminators, + 
position_ids=position_ids, # Note: must add position_ids + attention_mask=attention_mask, + **kwargs, + )[0] + # output = output[:,input_ids.shape[1]:] # paddle no need this + if decode_text: + return self._decode_text(output, tokenizer) + return output + + def _decode_stream(self, input_ids, image_embeds, media_offset, tokenizer, **kwargs): + terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] + streamer = TextIteratorStreamer(tokenizer=tokenizer) + generation_kwargs = { + "input_ids": input_ids, + "image_embeds": image_embeds, + "media_offset": media_offset, + "pad_token_id": 0, + "eos_token_id": terminators, + "streamer": streamer, + } + generation_kwargs.update(kwargs) + + thread = Thread(target=self.language_model.generate, kwargs=generation_kwargs) + thread.start() + + return streamer + + def _decode_text(self, result_ids, tokenizer): + terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] + result_text = [] + for result in result_ids: + result = result[result != 0] + if result[-1] in terminators: + result = result[:-1] + result_text.append(tokenizer.decode(result).strip()) + return result_text + + def init_processor(self, tokenizer): + ip = mPLUGOwl3ImageProcessor(image_size=378) + self.processor = mPLUGOwl3Processor(image_processor=ip, tokenizer=tokenizer) + processor = self.processor + return processor + + def generate( + self, + input_ids=None, + pixel_values=None, + media_offset=None, + attention_mask=None, + tokenizer=None, + stream=False, + decode_text=False, + **kwargs + ): + assert input_ids is not None + + with paddle.no_grad(): + image_embeds = self.forward_image(pixel_values) + + if stream: + result = self._decode_stream( + input_ids=input_ids, + image_embeds=image_embeds, + media_offset=media_offset, + tokenizer=tokenizer, + **kwargs, + ) + else: + result = self._decode( + input_ids=input_ids, + image_embeds=image_embeds, + media_offset=media_offset, + tokenizer=tokenizer, + attention_mask=attention_mask, + decode_text=decode_text, + **kwargs, + ) + + return result + + def chat( + self, + images, + videos, + messages, + tokenizer, + processor=None, + max_new_tokens=2048, + min_new_tokens=0, + sampling=True, + max_inp_length=8192, + system_prompt="", + stream=False, + max_slice_nums=None, + use_image_id=None, + **kwargs + ): + cut_flag = kwargs.get("kwargs", True) + if processor is None: + if self.processor is None: + processor = self.init_processor(tokenizer) + else: + processor = self.processor + inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag) + inputs.update( + { + "tokenizer": tokenizer, + "max_new_tokens": max_new_tokens, + # 'stream':True, + } + ) + if sampling: + generation_config = { + "top_p": 0.8, + "top_k": 100, + "temperature": 0.7, + "do_sample": True, + # "repetition_penalty": 1.05 + } + else: + generation_config = { + "num_beams": 3, + # "repetition_penalty": 1.2, + } + + if min_new_tokens > 0: + generation_config["min_new_tokens"] = min_new_tokens + + generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) + with paddle.no_grad(): + res = self.generate(**inputs, stream=stream, decode_text=True, **generation_config) + + if stream: + + def stream_gen(): + for text in res: + for term in self.terminators: + text = text.replace(term, "") + yield text + + return stream_gen() + + else: + answer = res[0] + return answer diff --git a/paddlemix/models/mPLUGOwl3/modeling_navit_siglip.py b/paddlemix/models/mPLUGOwl3/modeling_navit_siglip.py new file mode 
100644 index 000000000..1f399d640 --- /dev/null +++ b/paddlemix/models/mPLUGOwl3/modeling_navit_siglip.py @@ -0,0 +1,743 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import paddle +from paddle import nn +from paddlenlp.transformers import PretrainedConfig +from paddlenlp.transformers.activations import ACT2FN +from paddlenlp.transformers.model_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ModelOutput, +) +from paddlenlp.transformers.model_utils import PretrainedModel + +from paddlemix.models.flash_attn_utils import has_flash_attn_func +from paddlemix.utils.initializer import _calculate_fan_in_and_fan_out + +flash_attn_func, flash_attn_varlen_func = has_flash_attn_func() + + +@dataclass +class PaddleAttentionMaskConverter: + """ + A utility attention mask class for Paddle that allows one to: + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask + (batch_size, 1, query_length, key_value_length) + """ + + @staticmethod + def _expand_mask(mask: paddle.Tensor, dtype: str, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.shape + tgt_len = tgt_len if tgt_len is not None else src_len + + # Expand dimensions: [bsz, 1, 1, src_len] + expanded_mask = mask.unsqueeze([1, 2]) + + # Broadcast to target shape: [bsz, 1, tgt_len, src_len] + expanded_mask = paddle.expand(expanded_mask, shape=[bsz, 1, tgt_len, src_len]) + expanded_mask = expanded_mask.astype(dtype) + + # Invert the mask (1.0 for positions to attend to) + inverted_mask = 1.0 - expanded_mask + + # Replace 1s with large negative values + min_value = paddle.to_tensor(float("-1e9"), dtype=dtype) + inverted_mask = paddle.where(inverted_mask.astype("bool"), min_value, paddle.zeros_like(inverted_mask)) + + return inverted_mask + + +def _prepare_4d_attention_mask(mask: paddle.Tensor, dtype: str, tgt_len: Optional[int] = None): + """ + Creates a 4D attention mask from a 2D mask. 
+ + Args: + mask (paddle.Tensor): A 2D attention mask of shape (batch_size, key_value_length) + dtype (str): The dtype the created mask should have + tgt_len (int, optional): The target length the created mask should have + + Returns: + paddle.Tensor: A 4D attention mask of shape (batch_size, 1, query_length, key_value_length) + """ + return PaddleAttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +class SigLipVisionConfig(PretrainedConfig): + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + _attn_implementation="eager", + **kwargs + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._attn_implementation = _attn_implementation + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + # cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + print( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +def _trunc_normal_(tensor, mean, std, a, b): + # 确保mean是浮点数 + mean = float(mean) + std = float(std) + a = float(a) + b = float(b) + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if mean < a - 2 * std or mean > b + 2 * std: + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.", + stacklevel=2, + ) + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(min=2 * l - 1, max=2 * u - 1) + if tensor.dtype in ["float16", "bfloat16"]: + og_dtype = tensor.dtype + tensor = tensor.to("float32") + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + tensor.multiply_(y=paddle.to_tensor(std * math.sqrt(2.0))) + tensor.add_(y=paddle.to_tensor(mean)) + if tensor.dtype == "float16": + tensor = tensor.to("float32") + tensor.clip_(min=a, max=b) + tensor = tensor.to("float16") + else: + tensor.clip_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: paddle.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> paddle.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}( ext{mean}, ext{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq ext{mean} \\leq b`. 
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with paddle.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.multiply_(y=paddle.to_tensor(std)).add_(y=paddle.to_tensor(mean)) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + variance = scale / denom + if distribution == "truncated_normal": + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.8796256610342398) + elif distribution == "normal": + with paddle.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with paddle.no_grad(): + tensor.uniform_(min=-bound, max=bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[paddle.Tensor] = None + last_hidden_state: paddle.float32 = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + + +class SiglipVisionEmbeddings(nn.Layer): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2D( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward( + self, + pixel_values: paddle.Tensor, + patch_attention_mask: paddle.Tensor, + tgt_sizes: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose([0, 2, 1]) + + max_im_h, max_im_w = pixel_values.shape[2], pixel_values.shape[3] + max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, max_im_w // self.patch_size) + boundaries = paddle.arange(start=1 / self.num_patches_per_side, end=1.0, step=1 / self.num_patches_per_side) + position_ids = paddle.full(shape=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = paddle.arange(start=0, end=1 - 1e-06, step=1 / nb_patches_h) + fractional_coords_w = paddle.arange(start=0, end=1 - 1e-06, step=1 / nb_patches_w) + bucket_coords_h = paddle.bucketize(x=fractional_coords_h, sorted_sequence=boundaries, right=True) + bucket_coords_w = paddle.bucketize(x=fractional_coords_w, sorted_sequence=boundaries, right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx].scatter_( + paddle.nonzero(p_attn_mask.reshape([-1]))[:, 0], pos_ids.astype(position_ids.dtype) + ) + position_ids = position_ids.to(self.position_embedding.weight.place) + + embeddings = embeddings + self.position_embedding(position_ids.cast("int64")) + return embeddings + + +class SigLipAttention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape([batch_size, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + key_states = key_states.reshape([batch_size, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + value_states = value_states.reshape([batch_size, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3]) + + k_v_seq_len = key_states.shape[-2] + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) * self.scale + + if attn_weights.shape != [batch_size, self.num_heads, q_len, k_v_seq_len]: + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is not None: + if attention_mask.shape != [batch_size, 1, q_len, k_v_seq_len]: + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.shape}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, axis=-1) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = paddle.matmul(attn_weights, value_states) + + if attn_output.shape != [batch_size, self.num_heads, q_len, self.head_dim]: + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.shape}" + ) + + attn_output = attn_output.transpose([0, 2, 1, 3]).contiguous() + attn_output = attn_output.reshape([batch_size, q_len, self.embed_dim]) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SigLipAttention): + """ + SigLipAttention flash attention module. This module inherits from `SigLipAttention` as the weights of the module stay + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them.
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + output_attentions = False + bsz, q_len, _ = hidden_states.shape + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1]) + key_states = key_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1]) + value_states = value_states.reshape([bsz, q_len, self.num_heads, self.head_dim]).transpose([0, 2, 1]) + + kv_seq_len = tuple(key_states.shape)[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + query_states = query_states.transpose([0, 2, 1]) + key_states = key_states.transpose([0, 2, 1]) + value_states = value_states.transpose([0, 2, 1]) + + dropout_rate = self.dropout if self.training else 0.0 + input_dtype = query_states.dtype + if input_dtype == paddle.float32: + if paddle.amp.is_auto_cast_enabled(): + target_dtype = paddle.amp.get_default_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + if not output_attentions: + attn_weights = None + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`paddle.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`paddle.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`paddle.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`paddle.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + + if attention_mask is not None: + raise NotImplementedError("Currently only support single image infer and attention_mask is none") + + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + causal=causal, # no softmax_scale= + )[0] + + return attn_output + + +class SigLipMLP(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SigLipEncoderLayer(nn.Layer): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: paddle.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + """ + Args: + hidden_states (`paddle.Tensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`paddle.Tensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PretrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + init_Normal = nn.initializer.Normal(std=1 / np.sqrt(width)) + init_Normal(module.position_embedding.weight) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SigLipAttention): + # 初始化投影层权重 + for proj in [module.q_proj, module.k_proj, module.v_proj, module.out_proj]: + init_Normal = nn.initializer.Normal() + init_Normal(proj.weight) + # 使用assign替代原地操作初始化偏置 + if hasattr(proj, "bias") and proj.bias is not None: + proj.bias.set_value(paddle.zeros_like(proj.bias)) + + elif isinstance(module, SigLipMLP): + # 初始化FC层权重 + init_Normal = nn.initializer.Normal() + init_Normal(module.fc1.weight) + init_Normal(module.fc2.weight) + + # 使用assign初始化偏置 + if hasattr(module.fc1, "bias") and module.fc1.bias is not None: + module.fc1.bias.set_value(paddle.normal(shape=module.fc1.bias.shape, mean=0.0, std=1e-06)) + if hasattr(module.fc2, "bias") and module.fc2.bias is not None: + module.fc2.bias.set_value(paddle.normal(shape=module.fc2.bias.shape, mean=0.0, std=1e-06)) + + elif isinstance(module, (nn.Linear, nn.Conv2D)): + lecun_normal_(module.weight) + if module.bias is not None: + module.bias.set_value(paddle.zeros_like(module.bias)) + + elif isinstance(module, nn.LayerNorm): + # 使用set_value替代原地操作 + if module.bias is not None: + module.bias.set_value(paddle.zeros_like(module.bias)) + if module.weight is not None: + module.weight.set_value(paddle.ones_like(module.weight)) + + +class SigLipEncoder(nn.Layer): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.LayerList(sublayers=[SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutput]: + """ + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + layer_outputs = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SigLipVisionTransformer(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(normalized_shape=embed_dim, epsilon=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # self.post_init() + + def get_input_embeddings(self) -> nn.Layer: + return self.embeddings.patch_embedding + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[paddle.Tensor] = None, + tgt_sizes: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + """ + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = pixel_values.shape[0] + if patch_attention_mask is None: + patch_attention_mask = paddle.ones( + shape=( + batch_size, + pixel_values.shape[2] // self.config.patch_size, + pixel_values.shape[3] // self.config.patch_size, + ), + dtype="bool", + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes + ) + patch_attention_mask = patch_attention_mask.reshape([batch_size, -1]) + if not paddle.any(x=~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 + else patch_attention_mask + ) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + 
last_hidden_state = self.post_layernorm(last_hidden_state) + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/paddlemix/processors/__init__.py b/paddlemix/processors/__init__.py index 7a05f5974..1d9ee61eb 100644 --- a/paddlemix/processors/__init__.py +++ b/paddlemix/processors/__init__.py @@ -16,21 +16,22 @@ from .blip_processing import * from .clip_processing import * from .eva02_processing import * +from .got_process import * from .groundingdino_processing import * +from .image_processing_minicpmv import * from .imagebind_processing import * from .internlm_xcomposer2_processing import * from .internvl_processing import * +from .janus_processing import * from .llava_next_processing import * from .llava_processing import * from .minigpt4_image_processing import * from .minigpt4_processing import * +from .mplugowl3_processing import * +from .processing_minicpmv import * from .qwen2_vl_processing import * from .qwen_vl_processing import * from .sam_processing import * from .tokenizer import SimpleTokenizer, tokenize from .visualglm_image_processing import * from .visualglm_processing import * -from .image_processing_minicpmv import * -from .processing_minicpmv import * -from .janus_processing import * -from .got_process import * diff --git a/paddlemix/processors/mplugowl3_processing.py b/paddlemix/processors/mplugowl3_processing.py new file mode 100644 index 000000000..93e9b4611 --- /dev/null +++ b/paddlemix/processors/mplugowl3_processing.py @@ -0,0 +1,824 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
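Note: with `from .mplugowl3_processing import *` added to `paddlemix/processors/__init__.py` above, the processing classes defined in the new module below should also be importable from the package root. A minimal sketch of the two equivalent import paths (assuming PaddleMIX is installed or on `PYTHONPATH`):

```python
# Both imports are expected to resolve once this patch is applied.
from paddlemix.processors import mPLUGOwl3ImageProcessor, mPLUGOwl3Processor

# Equivalent explicit import from the new module:
from paddlemix.processors.mplugowl3_processing import (
    mPLUGOwl3ImageProcessor,
    mPLUGOwl3Processor,
)
```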
+ +import random +import re +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +import paddle +import paddle.vision.transforms as transforms +from einops import rearrange, repeat +from paddle.vision.transforms import Resize +from paddlenlp.transformers.image_processing_utils import ( + BaseImageProcessor, + BatchFeature, +) +from paddlenlp.transformers.processing_utils import ProcessorMixin +from PIL import Image + +OWL_MEDIA_TOKEN = ["<|image|>"] + + +def recursive_converter(converter, value): + if isinstance(value, list): + new_value = [] + for v in value: + new_value += [recursive_converter(converter, v)] + return new_value + else: + return converter(value) + + +def box_area(boxes): + # 获取边界框的宽度和高度 + width = boxes[:, 2] - boxes[:, 0] + height = boxes[:, 3] - boxes[:, 1] + # 计算面积 + area = width * height + return area + + +def custom_max(a, b): + return paddle.where(a > b, a, b) + + +def custom_min(a, b): + return paddle.where(a < b, a, b) + + +def box_iou(boxes1, area1, boxes2, eps=1e-05): + # >>>>>> area2 = torchvision.ops.boxes.box_area(boxes2) + area1 = area1.astype("float32") + boxes1 = boxes1.astype("float32") + boxes2 = boxes2.astype("float32") + + area2 = box_area(boxes2).astype("float32") + lt = custom_max(boxes1[:, None, :2], boxes2[:, :2]) + rb = custom_min(boxes1[:, None, 2:], boxes2[:, 2:]) + wh = (rb - lt).clip(min=0) + inter = wh[:, :, 0] * wh[:, :, 1] + union = area1[:, None] + area2 - inter + iou = inter / (union + eps) + return iou, union + + +# def box_iou(boxes1, area1, boxes2, eps=1e-5): +# area2 = box_area(boxes2) + +# lt = paddle.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] +# rb = paddle.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + +# wh = (rb - lt).clip(min=0) # [N,M,2] +# inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + +# union = area1[:, None] + area2 - inter + +# iou = inter / (union + eps) +# return iou, union + + +available_anchor_strategy = ["docowl", "random", "highest", "last", "llava"] + +grid_dict = { + "grid_33": [ + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (2, 2), + (1, 4), + (4, 1), + (1, 5), + (5, 1), + (1, 6), + (6, 1), + (2, 3), + (3, 2), + (1, 7), + (7, 1), + (4, 2), + (2, 4), + (1, 8), + (8, 1), + (3, 3), + (1, 9), + (9, 1), + ], + "grid_squ_3x3": [(1, 1), (2, 2), (3, 3)], + "grid_squ_4": [(2, 2), (1, 3), (1, 4), (3, 1), (4, 1)], + "grid_squ_6": [(2, 2), (1, 3), (1, 4), (3, 1), (4, 1), (2, 3), (3, 2)], + "grid_squ_2": [(2, 1)], + "grid_squ_9": [ + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (2, 2), + (1, 4), + (4, 1), + (1, 5), + (5, 1), + (1, 6), + (6, 1), + (2, 3), + (3, 2), + (1, 7), + (7, 1), + (4, 2), + (2, 4), + (1, 8), + (8, 1), + (3, 3), + (1, 9), + (9, 1), + ], +} + + +cut_prompt_template_dict = { + 'v0': lambda img_token, h, w: f''.join([f"{img_token}" for i in range(h) for j in range(w)]), + 'v1': lambda img_token, h, w: f'Cut to {h} rows {w} columns, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]), + 'v1_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view, '+ ' '.join([f"subimg({i},{j}){img_token}"for i in range(h) for j in range(w)]+[f"global_view{img_token}"]), + 'v2_global': lambda img_token, h, w: f'Cut to {h} rows {w} columns with a global view\n'+ '\n'.join([' '.join([f"subimg({i},{j}){img_token}" for j in range(w)]) for i in range(h)])+f"\nglobal_view{img_token}", + 'v3': lambda img_token, h, w: f'<|start_cut|>{h}*{w}'+ ' '.join([f"{img_token}"for i in range(h) for j in range(w)])+'<|end_cut|>', + 'v3_global': lambda 
img_token, h, w: f'<|start_cut|>{h}*{w}\n'+ '\n'.join([' '.join([f"{img_token}" for j in range(w)]) for i in range(h)])+f'\n{img_token}<|end_cut|>', +} + + +def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5): + # anchors x1 y1 x2 y2 + + # image_size: (h, w) + # xyxy + input_image_bbox = paddle.to_tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0) + + boxes1 = anchors + boxes2 = input_image_bbox + boxes3 = anchors.clone() + # y2 + boxes3[:, 3] = input_image_size[0] / input_image_size[1] * anchors[:, 2] # used to compute a resolution-independent IoU + + area1 = anchors_areas + + iou, _ = box_iou(boxes1, area1, boxes2) + iou = iou.squeeze(1) + shape_iou, _ = box_iou(boxes1, area1, boxes3) + shape_iou = shape_iou.diag() + # prefer anchors with a similar shape first, then a similar resolution + index = paddle.argmax(shape_iou * 100 + iou, axis=0) + return index + + +def select_best_resolution(anchors, anchors_areas, input_image_size): # TODO For a further check + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + original_size (tuple): The original size of the image in the format (width, height). + possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + original_size = (input_image_size[1], input_image_size[0]) + possible_resolutions = [(_[2], _[3]) for _ in anchors] # xyxy -> w,h + + original_width, original_height = original_size + # best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + index = 0 + for i, (width, height) in enumerate(possible_resolutions): + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + # best_fit = (width, height) + index = i + + return index + + +def build_cut_shape_indices(cut_shape): + # cut_shape: a list of (nh,nw) + cut_shape_indices = [] + for shape in cut_shape: + n = shape[0] * shape[1] + indices = paddle.concat( + [repeat(paddle.to_tensor(shape), "l -> n l", n=n), paddle.arange(n).unsqueeze(1)], axis=1 + ) + assert indices.shape[0] == n + assert indices.shape[1] == 3 # nh,nw,idx + + cut_shape_indices.append(indices) + cut_shape_indices = paddle.concat(cut_shape_indices, axis=0).astype("int64") + return cut_shape_indices + + +class AnchorResize(paddle.nn.Layer): + def __init__(self, image_size, anchors, interpolation="bilinear", antialias=None, anchor_strategy="docowl"): + super().__init__() + self.image_size = image_size + # xyxy + self.anchors = paddle.to_tensor( + [[0, 0, _[1] * image_size[1], _[0] * image_size[0]] for _ in anchors], + ) + + self.anchor_areas = box_area(self.anchors) + + self.interpolation = interpolation + self.antialias = antialias + self.anchor_strategy = anchor_strategy + assert self.anchor_strategy in available_anchor_strategy + + def resize_global(self, img): + transform = Resize(size=self.image_size, interpolation=self.interpolation) + return transform(img) + + def forward(self, img, skip_resize=False): + """ + Args: + img
(PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + if self.anchor_strategy == "docowl": + selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0])) + elif self.anchor_strategy == "random": + selected_anchor = random.randint(0, len(self.anchors) - 1) + elif self.anchor_strategy == "highest": + # 选面积最大的 在这个基础上 尽可能选最方正的 + selected_anchor = paddle.argmax( + self.anchors[:, 2] * self.anchors[:, 3] * 100 - paddle.abs(self.anchors[:, 2] - self.anchors[:, 3]) + ) + elif self.anchor_strategy == "last": + selected_anchor = len(self.anchors) - 1 + elif self.anchor_strategy == "llava": + selected_anchor = select_best_resolution(self.anchors, self.anchor_areas, (img.size[1], img.size[0])) + else: + selected_anchor = None + assert selected_anchor is not None + + target_size = self.anchors[selected_anchor][2:].tolist() # w,h + if skip_resize: + # for debug + return selected_anchor + # return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor + # image_np = np.array(img) + # image_tensor = paddle.to_tensor(image_np, dtype="float32") + # image_tensor = image_tensor.transpose([2, 0, 1]) # 变成 (3, 500, 500) + # if self.interpolation == "bilinear" or "bicubic": + # image_tensor = image_tensor.unsqueeze(0) # 变成 (1, 3, 500, 500) + transform = Resize(size=[target_size[1], target_size[0]], interpolation=self.interpolation) + return (transform(img), selected_anchor) + # return ( + # F.interpolate( + # image_tensor, size=[target_size[1], target_size[0]], mode=self.interpolation, align_corners=False + # )[0], + # selected_anchor, + # ) + + def __repr__(self) -> str: + detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})" + return f"{self.__class__.__name__}{detail}" + + +class CutMixin: + def __init__( + self, + cut_cfg={ + "anchors": "grid_squ_6", + "anchor_strategy": "docowl", + "cut_prompt": "v3", + "add_global": True, + "cut_prob": 1.0, + }, + ) -> None: + if cut_cfg is None: + self.cut_enable = False + return + else: + self.cut_enable = True + image_size = self.image_size + anchors = cut_cfg.get("anchors", "grid_33") + anchor_strategy = cut_cfg.get("anchor_strategy", "docowl") + cut_prompt = cut_cfg.get("cut_prompt", "v0") + self.cut_prob = cut_cfg.get("cut_prob", 1.0) + + self.force_shape_cut = cut_cfg.get("force_shape_cut", False) + force_shape_cut_anchors = cut_cfg.get("force_shape_cut_anchors", "force_shape_cut_anchors") + + self.add_global = cut_cfg.get("add_global", False) + + # h,w + if isinstance(image_size, int): + image_size = (image_size, image_size) + self.image_size = image_size + + if anchors in grid_dict: + anchors = grid_dict[anchors] + else: + anchors = eval(anchors) + self.anchors = [tuple(_) for _ in anchors] + self.anchor_max = max([max(_) for _ in self.anchors]) + self.resizer = AnchorResize( + image_size=image_size, anchors=anchors, interpolation="bicubic", anchor_strategy=anchor_strategy + ) + + if force_shape_cut_anchors in grid_dict: + force_shape_cut_anchors = grid_dict[force_shape_cut_anchors] + else: + force_shape_cut_anchors = eval(force_shape_cut_anchors) + self.force_shape_cut_anchors = [tuple(_) for _ in force_shape_cut_anchors] + self.force_shape_cut_anchors_max = max([max(_) for _ in self.force_shape_cut_anchors]) + + self.old_resizer = transforms.Resize(image_size, interpolation="bicubic") + + # 把image processor的缩放去掉 只保留后面的变换 + 
self.image_transform = transforms.Compose(self.image_transform.transforms[1:]) + if self.add_global: + self.cut_prompt_template = cut_prompt_template_dict[cut_prompt + "_global"] + else: + self.cut_prompt_template = cut_prompt_template_dict[cut_prompt] + + self.media_tokens = ["<|image|>", "<|video|>"] + + def _process_image(self, images): + new_images = [] + cut_shape = [] + for image in images: + raw_image = image + image, selected_anchor = self.resizer(image) + image_input = self.image_transform(image) # h,w,3 -> 3,h,w + cut_shape.append( + (image_input.shape[1] // self.image_size[0], image_input.shape[2] // self.image_size[1]) + ) # cut_h, cut_w + image_input = rearrange( + image_input, "C (num_h h) (num_w w) -> (num_h num_w) C h w", h=self.image_size[0], w=self.image_size[1] + ) + + new_images.append(image_input) + + if self.add_global: + new_images.append(self.image_transform(self.resizer.resize_global(raw_image)).unsqueeze(0)) + cut_shape.append((1, 1)) + + new_images = paddle.concat(new_images, axis=0) + cut_shape_indices = build_cut_shape_indices(cut_shape) + return new_images, cut_shape, cut_shape_indices + + +class TensorType(Enum): + PADDLE = "paddle" + + +class mPLUGOwl3BatchFeature(BatchFeature): + r""" + Extend from BatchFeature for supporting various image size + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + if tensor_type is None: + return self + + # is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type) + is_tensor = lambda x: isinstance(x, paddle.Tensor) + as_tensor = paddle.to_tensor + + def converter(value): + try: + if not is_tensor(value): + tensor = as_tensor(value) + return tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." 
+ ) + + for key, value in self.items(): + self[key] = recursive_converter(converter, value) + return self + + +class mPLUGOwl3ImageProcessor(BaseImageProcessor, CutMixin): + model_input_names = ["pixel_values"] + + def __init__(self, image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], **kwargs): + super().__init__(**kwargs) + self.image_size = image_size + self.image_transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size), interpolation="bicubic"), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] + ) + CutMixin.__init__(self) + + def preprocess( + self, images: Union[Image.Image, List[Image.Image]], cut_enable=True, **kwargs + ) -> mPLUGOwl3BatchFeature: + if isinstance(images, Image.Image): + images_list = [images] + else: + images_list = images + + if self.cut_enable and cut_enable: + image_data, cut_shape, cut_shape_indices = self._process_image(images_list) + else: + image_data = [self.image_transform(self.resizer.resize_global(image)) for image in images_list] + image_data = paddle.stack(image_data, axis=0) + cut_shape = cut_shape_indices = None + + return mPLUGOwl3BatchFeature( + data={"pixel_values": image_data, "cut_shape": cut_shape, "cut_shape_indices": cut_shape_indices} + ) + + def to_dict(self): + encoder_dict = super().to_dict() + pop_keys = ["image_transform", "resizer", "old_resizer", "cut_prompt_template"] + for pk in pop_keys: + encoder_dict.pop(pk, None) + return encoder_dict + + +class MediaIndicesHelper: + def __init__(self, tokenizer) -> None: + self.media_position = [] + self.tokenizer = tokenizer + + def has_media(self, text, media_tokens=None): + if media_tokens is None: + media_tokens = OWL_MEDIA_TOKEN + has_media_flag = any([media_token == text for media_token in media_tokens]) + if any([media_token in text for media_token in media_tokens]): + # 不允许出现text中包含media token但是不仅仅是media token。 media token必须单独为一个chunk + assert has_media_flag, text + return has_media_flag + + def add_media(self, text_chunk, text=None, tokenize_fn=None): + # cross + assert tokenize_fn is not None + assert text is not None + assert text in OWL_MEDIA_TOKEN + media_token_ids = tokenize_fn(text) + start = len(text_chunk) + end = start + len(media_token_ids) + self.media_position.append([start, end]) + text_chunk.extend(media_token_ids) + return len(media_token_ids) + + def cal_media_offset(self, input_ids): + if len(self.media_position) == 0: + return paddle.ones_like(input_ids) * (-1000000) + + media_starts = paddle.to_tensor([_[0] for _ in self.media_position]).reshape([1, -1]) + rng = paddle.arange(input_ids.shape[0]).reshape([-1, 1]) + matrix = (rng > media_starts).sum(axis=1) + + return matrix + + def len_images( + self, + ): + return len(self.media_position) + + +class mPLUGOwl3Processor(ProcessorMixin): + r""" + Args: + image_processor ([`mPLUGOwl3ImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerWrapper`], *optional*): + The tokenizer is a required input. 
+ """ + attributes = ["image_processor", "tokenizer"] + image_processor_class = "mPLUGOwl3ImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor: mPLUGOwl3ImageProcessor = None, + tokenizer=None, + prompt_style="chatml", + inference_mode=True, + addition_eod="<|endoftext|>", + ): + super().__init__(image_processor, tokenizer) + self.image_processor: mPLUGOwl3ImageProcessor + self.prompt_style = prompt_style + self.inference_mode = inference_mode + self.media_tokens = ["<|image|>"] + self.addition_eod = addition_eod + + def build_text_qwen(self, messages): + # role should be within ['system', 'user', 'assistant'] + im_start, im_end = "<|im_start|>", "<|im_end|>" + + text = [] + for num_turn, message in enumerate(messages): + if num_turn == 0 and message["role"] != "system": + if self.prompt_style != "plain": + text.append({"text": f"{im_start}system\n{im_end}", "label": 0}) + if message["role"] == "system": + if self.prompt_style != "plain": + text.append({"text": f"{im_start}system\n{message['content']}{im_end}", "label": 0}) + elif message["role"] == "user": + if self.prompt_style != "plain": + content = f"\n{im_start}user\n{message['content']}{im_end}" + else: + content = message["content"] + pattern = "|".join(map(re.escape, self.media_tokens)) + chunk_strs = re.split(f"({pattern})", content) + for chunk_str in chunk_strs: + text.append({"text": chunk_str, "label": 0}) + + elif message["role"] == "assistant": + if self.prompt_style != "plain": + text.append({"text": f"\n{im_start}assistant\n", "label": 0}) + text.append({"text": f"{message['content']}{im_end}", "label": 1}) + else: + text.append({"text": f"{message['content']}", "label": 1}) + text.append({"text": self.addition_eod, "label": 1}) + else: + raise NotImplementedError + if self.inference_mode: + while text and text[-1]["label"] == 1: # 只要列表非空且最后一个元素满足条件 + text.pop() # 就移除最后一个元素 + return text + + def wrapped_tokenize(self, text): + return self.tokenizer(text).input_ids + + def encode_text_sft(self, texts): + # output enc_chunk + + enc_chunk = [] + label_chunk = [] + enc_length = 0 + + num_images = 0 + + media_helper = MediaIndicesHelper(tokenizer=self.tokenizer) + for current_ti, text_chunk in enumerate(texts): + + text = text_chunk["text"] + label = text_chunk["label"] + + if not media_helper.has_media(text): + curr_chunk = self.wrapped_tokenize(text) + if label == 1: + enc_length += len(curr_chunk) + enc_chunk += curr_chunk + label_chunk += [label] * len(curr_chunk) + else: + + enc_length += len(curr_chunk) + enc_chunk += curr_chunk + label_chunk += [label] * len(curr_chunk) + # For media tokens + else: + + add_length = media_helper.add_media(enc_chunk, text=text, tokenize_fn=self.wrapped_tokenize) + enc_length += add_length + label_chunk += [label] * add_length + # enc_chunk.extend([self.media_tokens[text]] * self.media_lengths[text]) + # enc_length += self.media_lengths[text] + # label_chunk += [label] * self.media_lengths[text] + num_images += 1 + + enc_chunk = paddle.to_tensor(enc_chunk).astype(dtype="int64") + media_offset = [paddle.to_tensor([_[0] for _ in media_helper.media_position]).astype(dtype="int64")] + return { + "input_ids": enc_chunk.unsqueeze(0), + "media_offset": media_offset, + } + + def __call__( + self, + messages, + images=None, + videos=None, + max_length: Optional[int] = None, + cut_enable=True, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PADDLE, + **kwargs + ) -> mPLUGOwl3BatchFeature: + medias = [] + if videos is not None: + 
medias.extend([{"type": "video", "content": video, "use_video_span": True} for video in videos]) + if images is not None: + medias.extend([{"type": "image", "content": image} for image in images]) + + if len(medias): + image_tensor_list = [] + pattern = r"(<\|image\|>|<\|video\|>)" + # 存在媒体 + image_token_ptr = 0 + # media_layout = [] + for message in messages: + text_list = re.split(pattern, message["content"]) + text = "" + for text_content in text_list: + if text_content in ["<|image|>", "<|video|>"]: + media_item = medias[image_token_ptr] + image_token_ptr += 1 + if text_content == "<|image|>": + assert media_item["type"] == "image" + image = media_item["content"] + + image_inputs = self.image_processor( + [image], cut_enable=cut_enable, return_tensors=return_tensors + ) + if image_inputs.get("cut_shape", None) is not None: + cut_shape = image_inputs["cut_shape"] + cut_text = self.image_processor.cut_prompt_template( + img_token="<|image|>", h=cut_shape[0][0], w=cut_shape[0][1] + ) + text += cut_text + image_tensor_list.append(image_inputs["pixel_values"]) + else: + text += text_content + image_tensor_list.append(image_inputs["pixel_values"]) + elif text_content == "<|video|>": + assert media_item["type"] == "video" + video = media_item["content"] + use_video_span = media_item["use_video_span"] + image_tensor = self.image_processor(video, cut_enable=False)["pixel_values"] + image_tensor_list.append(image_tensor) + num_video_frame = image_tensor.shape[0] + if use_video_span: + text_content = ( + "<|start_video_frame|>" + "<|image|>" * num_video_frame + "<|end_video_frame|>" + ) + else: + text_content = "<|image|>" * num_video_frame + text += text_content + else: + text += text_content + message["content"] = text + assert image_token_ptr == len(medias), (image_token_ptr, len(medias)) # 保证图和token数目一致 + assert all(len(_.shape) == 4 for _ in image_tensor_list), [_.shape for _ in image_tensor_list] + num_image_tokens = sum([_["content"].count("<|image|>") for _ in messages]) + num_image_shapes = sum([_.shape[0] for _ in image_tensor_list]) + assert num_image_tokens == num_image_shapes, (messages, [_.shape for _ in image_tensor_list]) + + image_tensor_list = paddle.concat(image_tensor_list, axis=0) + + text = self.build_text_qwen(messages) + model_inputs = self.encode_text_sft(text) + + if len(medias) is not None: + model_inputs.update({"pixel_values": image_tensor_list}) + # if 'cut_shape' in model_inputs: + # model_inputs.pop('cut_shape') + # if 'cut_shape_indices' in model_inputs: + # model_inputs.pop('cut_shape_indices') + return mPLUGOwl3BatchFeature(model_inputs) + + def check_media(self, images, messages): + media_num = 0 if images is None else len(images) + media_count = sum([message["content"].count("<|image|>") for message in messages]) + assert media_num == media_count + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + output_ids = args[0] + result_text = [] + for result in output_ids: + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id: + result = result[:-1] + result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip()) + return result_text + # return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + result = args[0] + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id or ( + hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id + ): + result = result[:-1] + return self.tokenizer.decode(result, *args[1:], **kwargs).strip() + + def _convert(self, input_str, max_inp_length: Optional[int] = None): + if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False): + input_ids = self.tokenizer.encode(input_str) + else: + input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str) + if max_inp_length is not None: + input_ids = input_ids[:max_inp_length] + input_ids = paddle.to_tensor(data=input_ids, dtype="int32") + + start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id) + end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id) + + image_start_tokens = paddle.where(start_cond)[0] # or paddle.nonzero(start_cond)[:, 0] + image_start_tokens += 1 + image_end_tokens = paddle.where(end_cond)[0] + + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + + image_bounds = paddle.hstack( + [ + image_start_tokens[:valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1), + ] + ) + return input_ids, image_bounds + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"): + items = [] + if isinstance(inputs[0], list): + assert isinstance(inputs[0][0], paddle.Tensor) + for it in inputs: + for tr in it: + items.append(tr) + else: + assert isinstance(inputs[0], paddle.Tensor) + items = inputs + + batch_size = len(items) + shape = items[0].shape + dim = len(shape) + assert dim <= 2 + if max_length is None: + max_length = 0 + max_length = max(max_length, max(item.shape[-1] for item in items)) + min_length = min(item.shape[-1] for item in items) + dtype = items[0].dtype + + if dim == 0: + return paddle.stack([item for item in items], axis=0), [0] + elif dim == 1: + if max_length == min_length: + return paddle.stack([item for item in items], axis=0), [0] * batch_size + tensor = paddle.zeros((batch_size, max_length), dtype=dtype) + padding_value + else: + tensor = paddle.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value + + padding_length = [] + for i, item in enumerate(items): + if dim == 1: + if padding_side == "left": + tensor[i, -len(item) :] = item.clone() + else: + tensor[i, 
: len(item)] = item.clone() + elif dim == 2: + if padding_side == "left": + tensor[i, -len(item) :, :] = item.clone() + else: + tensor[i, : len(item), :] = item.clone() + padding_length.append(tensor.shape[-1] - len(item)) + + return tensor, padding_length
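For reference, a small self-contained sketch of the left-padding behavior that `pad` above implements for the common 1-D case; the tensors, expected values, and shapes below are illustrative only and do not call into the processor itself:

```python
import paddle

# Two "token id" sequences of different lengths.
a = paddle.to_tensor([1, 2, 3], dtype="int64")
b = paddle.to_tensor([4, 5], dtype="int64")

# pad() aligns items to the longest sequence (here 3), filling with
# padding_value=0 on the left when padding_side="left" (the default).
max_len = max(a.shape[-1], b.shape[-1])
batch = paddle.zeros([2, max_len], dtype="int64")  # zeros == padding_value
batch[0, -a.shape[-1] :] = a
batch[1, -b.shape[-1] :] = b
padding_length = [max_len - a.shape[-1], max_len - b.shape[-1]]

print(batch.numpy())   # [[1 2 3] [0 4 5]]
print(padding_length)  # [0, 1]
```

The actual helper additionally handles 2-D items, right-side padding, and an explicit `max_length`; this sketch only mirrors the default left-padded 1-D case.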