verl/models/transformers/kimi_vl.py (94 changes: 22 additions & 72 deletions)
@@ -19,51 +19,7 @@
 from transformers.cache_utils import Cache
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
-from verl.utils.ulysses import gather_heads_scatter_seq, gather_outpus_and_unpad, gather_seq_scatter_heads, get_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_rank, get_ulysses_sequence_parallel_world_size, validate_ulysses_config
-
-
-def _merge_with_image_features(
-    self,
-    inputs_embeds: torch.Tensor,
-    input_ids: torch.Tensor,
-    image_features: torch.Tensor,
-):
-    """
-    Args:
-        inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, input_embed_dim)`):
-            The input embeddings.
-        input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
-            The input ids.
-        image_features (:obj:`torch.Tensor` of shape :obj:`(image_token_nums, image_feature_dim)`):
-            The image features to merge with the input embeddings.
-    """
-    image_token_index: int = self.config.media_placeholder_token_id
-
-    batch_size, sequence_length, input_embed_dim = inputs_embeds.shape
-    image_feature_nums, image_feature_dim = image_features.shape
-
-    assert image_feature_dim == input_embed_dim
-
-    image_token_nums = (input_ids == image_token_index).sum()
-    total_image_token_nums = torch.tensor([image_token_nums], dtype=image_token_nums.dtype, device=input_ids.device)
-    total_image_token_nums = gather_outpus_and_unpad(total_image_token_nums, gather_dim=0) # [sp_size]
-    assert image_feature_nums == total_image_token_nums.sum()
-
-    # (batch_size, sequence_length / sp, input_embed_dim) -> (batch_size * sequence_length / sp, input_embed_dim)
-    inputs_embeds = inputs_embeds.reshape(-1, input_embed_dim)
-
-    # (batch_size, sequence_length / sp) -> (batch_size * sequence_length / sp)
-    input_ids = input_ids.flatten()
-
-    # split image features and fill in the image token positions if there are image tokens
-    sp_image_features = image_features.split(total_image_token_nums.tolist(), dim=0)
-    sp_rank = get_ulysses_sequence_parallel_rank()
-    image_features = sp_image_features[sp_rank]
-    inputs_embeds[input_ids == image_token_index] = image_features
-
-    inputs_embeds = inputs_embeds.reshape((batch_size, sequence_length, input_embed_dim))
-
-    return inputs_embeds
+from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size, validate_ulysses_config
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
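The deleted `_merge_with_image_features` above was a sequence-parallel-aware override: it all-gathered the per-rank image-token counts, split the flat image-feature tensor by those counts, and wrote only the local rank's slice into its shard of `inputs_embeds`. With this PR the stock merge can run on the full sequence and the inputs are sliced for Ulysses afterwards (see `patch_vlm_for_ulysses_input_slicing` in monkey_patch.py below), so the override is no longer needed. A minimal single-process sketch of the splitting logic the deleted helper performed; the function name and shapes here are illustrative, not verl API:

```python
import torch

def split_image_features_by_rank(image_features: torch.Tensor, tokens_per_rank: list[int], sp_rank: int) -> torch.Tensor:
    # image_features: (total_image_tokens, dim), concatenated over all SP ranks;
    # tokens_per_rank[r] is the number of image-placeholder tokens in rank r's shard.
    assert image_features.size(0) == sum(tokens_per_rank)
    return image_features.split(tokens_per_rank, dim=0)[sp_rank]

feats = torch.randn(6, 8)                                    # 6 image tokens, dim 8
print(split_image_features_by_rank(feats, [4, 2], 1).shape)  # torch.Size([2, 8])
```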
@@ -140,7 +96,6 @@ def _ulysses_flash_attn_forward(
     else:
         q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
         q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
-    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
     # Flash attention requires the input to have the shape
     # batch_size x seq_length x head_dim x hidden_dim
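In MLA, each query head is the concatenation of a content part of width `qk_nope_head_dim` and a rotary part of width `qk_rope_head_dim`, so `q_head_dim` is their sum. Deleting the split here and re-adding it after the all-to-all (next hunk) lets the full `q` tensor cross the a2a as one tensor instead of two. A quick shape check of that split, using example dimensions rather than KimiVL's actual config:

```python
import torch

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # example values, not KimiVL's config
q = torch.randn(2, 16, 10, qk_nope_head_dim + qk_rope_head_dim)  # (batch, heads, seq, q_head_dim)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
assert q_nope.shape[-1] == 128 and q_pe.shape[-1] == 64  # only q_pe receives RoPE
```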
@@ -151,46 +106,41 @@
     kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)
 
     k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-    kv_seq_len = value_states.shape[-2]
 
-    # patch to get all emb
+    # patch
     ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
-    kv_seq_len *= ulysses_sp_size
+    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.num_heads, ulysses_sp_size)
+
+        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads
+        k_pe = repeat_kv(k_pe, ulysses_sp_size)  # to keep heads=1 after a2a
+        k_nope = repeat_kv(k_nope, num_key_value_groups)
+        value_states = repeat_kv(value_states, num_key_value_groups)
+        q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1)
+        k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1)
+        k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1)
+        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
+        # (batch_size, num_head / sp_size, seq_length, head_size)
+        full_q_len = q.size(2) # full_q_len = seq_length
+    else:
+        full_q_len = q_len
 
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+    cos, sin = self.rotary_emb(value_states, seq_len=full_q_len)
     q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
-    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)
     query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
     query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)
     key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
     key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
 
     if self.q_head_dim != self.v_head_dim:
         value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
 
-    # patch
-    if ulysses_sp_size > 1:
-        validate_ulysses_config(self.num_heads, ulysses_sp_size)
-
-        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads
-        key_states = repeat_kv(key_states, num_key_value_groups)
-        value_states = repeat_kv(value_states, num_key_value_groups)
-        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
-        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
-        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
-        # (batch_size, num_head / sp_size, seq_length, head_size)
-        full_q_len = query_states.size(2) # full_q_len = seq_length
-
-        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
-        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
-        position_ids = torch.concat(position_ids_list, dim=-1)
-
-    else:
-        full_q_len = q_len
 
     # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
     # to be able to avoid many of these transpose/reshape/view.
     query_states = query_states.transpose(1, 2)
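The reworked hunk moves the Ulysses all-to-all in front of RoPE and the query/key assembly: `gather_seq_scatter_heads` leaves each rank with the full sequence but only `num_heads / sp_size` heads, which is why the `new_empty` buffers shrink their head dimension and why `k_pe` (a single shared head) is first repeated `sp_size` times so exactly one copy survives the head scatter. The old block's `all_gather` of `position_ids` also disappears, presumably because the caller now supplies positions matching the gathered sequence. A single-process simulation of what the a2a does to shapes; this only mimics the layout change and is not verl's `gather_seq_scatter_heads`:

```python
import torch

def gather_seq_scatter_heads_sim(shards: list[torch.Tensor], sp_rank: int) -> torch.Tensor:
    # shards[r]: rank r's tensor of shape (batch, heads, seq / sp, head_dim).
    # Result for sp_rank: (batch, heads / sp, seq, head_dim) -- full sequence, a head slice.
    sp = len(shards)
    heads_per_rank = shards[0].size(1) // sp
    h0 = sp_rank * heads_per_rank
    return torch.cat([s[:, h0 : h0 + heads_per_rank] for s in shards], dim=2)

shards = [torch.randn(1, 16, 5, 64) for _ in range(2)]  # sp=2, seq/sp=5
print(gather_seq_scatter_heads_sim(shards, 0).shape)    # torch.Size([1, 8, 10, 64])
```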
verl/models/transformers/monkey_patch.py (6 changes: 4 additions & 2 deletions)
@@ -214,12 +214,14 @@ def apply_monkey_patch(
     elif model.config.model_type == "kimi_vl":
         if use_remove_padding or ulysses_sp_size > 1:
             # TODO: Changes need to be made when transformers are adapted.
-            from verl.models.transformers.kimi_vl import _merge_with_image_features, _ulysses_flash_attn_forward
+            from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward
 
-            module.KimiVLForConditionalGeneration._merge_with_image_features = _merge_with_image_features
             module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward
             print("Monkey patch FlashAttention2.forward in KimiVL")
 
+        if ulysses_sp_size > 1:
+            patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)
+
         if use_fused_kernels:
             print(f"Not support fused kernels for KimiVL")
 
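With the SP-aware embedding merge gone, Ulysses support for KimiVL reduces to two patches: replacing `DeepseekV3FlashAttention2.forward` and slicing the language model's inputs along the sequence dimension before its forward runs. verl's `patch_vlm_for_ulysses_input_slicing` wraps the model's forward to do that slicing; the standalone helper below is only an illustration of the shape-level effect, not the actual patch:

```python
import torch

def slice_for_sp(t: torch.Tensor, sp_rank: int, sp_size: int, seq_dim: int = 1) -> torch.Tensor:
    # Keep only this rank's contiguous shard of the sequence dimension.
    seq_len = t.size(seq_dim)
    assert seq_len % sp_size == 0, "sequence length must be divisible by sp_size"
    shard = seq_len // sp_size
    return t.narrow(seq_dim, sp_rank * shard, shard)

embeds = torch.randn(2, 12, 8)           # (batch, seq, dim)
print(slice_for_sp(embeds, 1, 4).shape)  # torch.Size([2, 3, 8])
```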