[Model] Add Ernie4.5 VL Model Support #22514
The first file in the diff adds a new rotary-embedding class for Ernie4.5 VL (`@@ -0,0 +1,71 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from .mrope import MRotaryEmbedding
from .common import apply_rotary_emb_dispatch


class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
    """3D rotary positional embedding. 3D is t:time, h:height, w:width."""

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation of the 3D (t/h/w) rotary forward pass.

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            section_h = self.mrope_section[0]  # 22
            section_w = self.mrope_section[1]  # 22
            section_t = self.mrope_section[2]  # 20
            assert section_h == section_w
            # The rotary channels are laid out as [h w h w ... h w t t ... t]:
            # h/w channels are interleaved at the front, and the temporal
            # section occupies the trailing channels.
            section_cos_t = cos[..., -section_t:]
            section_cos_h = cos[..., :section_h + section_w:2]
            section_cos_w = cos[..., 1:section_h + section_w:2]
            # Select each section from its own position row (0: t, 1: h, 2: w).
            cos_t = section_cos_t[0]
            cos_h = section_cos_h[1]
            cos_w = section_cos_w[2]
            # Re-interleave h and w pairwise, then append the temporal section.
            cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape(
                cos_h.shape[:-1] + (cos_h.shape[-1] * 2, ))
            cos = torch.cat([cos_hw, cos_t], dim=-1)

            section_sin_t = sin[..., -section_t:]
            section_sin_h = sin[..., :section_h + section_w:2]
            section_sin_w = sin[..., 1:section_h + section_w:2]
            sin_t = section_sin_t[0]
            sin_h = section_sin_h[1]
            sin_w = section_sin_w[2]
            sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape(
                sin_h.shape[:-1] + (sin_h.shape[-1] * 2, ))
            sin = torch.cat([sin_hw, sin_t], dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                              self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                            self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key
```
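To make the `[h w h w ... t t t]` layout concrete, here is a small standalone sketch of the same slicing on a toy `cos` table. The sizes are hypothetical (`section = [2, 2, 2]` instead of the model's `[22, 22, 20]`); the `assert section_h == section_w` above exists precisely so the h and w sections can be re-interleaved pairwise:

```python
import torch

# Toy dimensions: 3 position rows (t/h/w), 4 tokens, 6 rotary cos channels,
# split as section_h = section_w = section_t = 2 -> layout [h w h w t t].
section_h, section_w, section_t = 2, 2, 2
cos = torch.arange(3 * 4 * 6, dtype=torch.float32).reshape(3, 4, 6)

cos_t = cos[..., -section_t:][0]                # trailing channels, temporal row
cos_h = cos[..., :section_h + section_w:2][1]   # even leading channels, height row
cos_w = cos[..., 1:section_h + section_w:2][2]  # odd leading channels, width row

# Re-interleave h/w pairwise, then append the temporal channels.
cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape(
    cos_h.shape[:-1] + (cos_h.shape[-1] * 2, ))
out = torch.cat([cos_hw, cos_t], dim=-1)
print(out.shape)  # torch.Size([4, 6]) -- one fused cos table for all tokens
```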
The second file hooks the new model types into the existing mrope position-ID dispatch (`@@ -158,6 +158,15 @@`):

```diff
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+            return cls._ernie_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         else:
             return cls._vl_get_input_positions_tensor(
                 input_tokens=input_tokens,
```
A review comment on lines +525 to +526 (the signature of the new helper) notes:

> I also have a feeling that a lot of operations in here can be optimized by vectorization, but we can revisit this later

The second hunk (`@@ -278,6 +287,118 @@`) adds the Ernie-specific position helper itself:

```diff
             len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _ernie_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for Ernie VL."""
+
+        image_token_id = hf_config.im_patch_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_conv_size = hf_config.spatial_conv_size
+        temporal_conv_size = hf_config.temporal_conv_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
+            # Tag each token as text/image/video; image tokens between the
+            # video start/end markers belong to a video.
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
+            # Collapse the per-token tags into (type, start, end) runs.
+            input_type_group: list[tuple[str, int, int]] = []
+            for key, group_iter in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]):
+                group_list = list(group_iter)
+                start_index = group_list[0][0]
+                end_index = group_list[-1][0] + 1
+                input_type_group.append((key, start_index, end_index))
+
+            video_frame_num = 1
+            mm_data_idx = 0
+            for modality_type, start_idx, end_idx in input_type_group:
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                    llm_pos_ids_list) > 0 else 0
+                if modality_type == "image":
+                    t, h, w = (
+                        image_grid_thw[mm_data_idx][0],
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = \
+                        t, h // spatial_conv_size, w // spatial_conv_size
+
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                        -1, llm_grid_h * llm_grid_w).flatten()
+                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                        llm_grid_t, -1, llm_grid_w).flatten()
+                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                        llm_grid_t, llm_grid_h, -1).flatten()
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + st_idx)
+                    mm_data_idx += 1
+
+                elif modality_type == "video":
+                    t, h, w = (
+                        video_grid_thw[mm_data_idx][0],
+                        video_grid_thw[mm_data_idx][1],
+                        video_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t // temporal_conv_size, h // spatial_conv_size,
+                        w // spatial_conv_size)
+
+                    for t_idx in range(llm_grid_t):
+                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
+                            -1, llm_grid_h * llm_grid_w).flatten()
+                        h_index = torch.arange(llm_grid_h).view(
+                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+                        w_index = torch.arange(llm_grid_w).view(
+                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx)
+
+                    mm_data_idx += 1
+                    video_frame_num += 1
+
+                else:
+                    text_len = end_idx - start_idx
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) +
+                        st_idx)
+                    video_frame_num = 1
+
+        else:
+            text_len = len(input_tokens)
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1))
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:seq_len]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _vl_get_input_positions_tensor(
         cls,
```
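As a quick illustration of what the image branch produces, here is a standalone re-run of its index construction for a hypothetical 1×4×4 patch grid with `spatial_conv_size = 2` (so the LLM sees a 1×2×2 token grid); this mirrors the `torch.arange(...).expand(...)` lines above rather than calling the method itself:

```python
import torch

# Hypothetical image grid: t=1, h=4, w=4 patches, spatial_conv_size=2,
# giving a 1 x 2 x 2 grid of image tokens on the LLM side.
t, h, w, conv = 1, 4, 4, 2
gt, gh, gw = t, h // conv, w // conv

t_index = torch.arange(gt).view(-1, 1).expand(-1, gh * gw).flatten()
h_index = torch.arange(gh).view(1, -1, 1).expand(gt, -1, gw).flatten()
w_index = torch.arange(gw).view(1, 1, -1).expand(gt, gh, -1).flatten()
print(torch.stack([t_index, h_index, w_index]))
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 1],
#         [0, 1, 0, 1]])
```

Each image token gets a distinct (t, h, w) coordinate; a text run that follows resumes from `max() + 1`, which is what the `st_idx` bookkeeping above implements.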
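On the reviewer's vectorization point: the per-frame loop in the video branch emits a constant temporal index per frame, so the concatenated result equals one full-grid expansion, exactly as in the image branch. A possible follow-up (a sketch, not part of this PR) could be:

```python
import torch

def video_positions_vectorized(llm_grid_t: int, llm_grid_h: int,
                               llm_grid_w: int, st_idx: int) -> torch.Tensor:
    """Build the same (3, t*h*w) index block the per-frame loop produces,
    in one shot: t is constant within a frame, so expanding over the full
    grid yields the identical concatenated layout."""
    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
        -1, llm_grid_h * llm_grid_w).flatten()
    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
        llm_grid_t, -1, llm_grid_w).flatten()
    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
        llm_grid_t, llm_grid_h, -1).flatten()
    return torch.stack([t_index, h_index, w_index]) + st_idx
```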