diff --git a/verl/models/transformers/kimi_vl.py b/verl/models/transformers/kimi_vl.py
index b2133a9c57b..87a43e29e49 100644
--- a/verl/models/transformers/kimi_vl.py
+++ b/verl/models/transformers/kimi_vl.py
@@ -19,51 +19,7 @@
 from transformers.cache_utils import Cache
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
-from verl.utils.ulysses import gather_heads_scatter_seq, gather_outpus_and_unpad, gather_seq_scatter_heads, get_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_rank, get_ulysses_sequence_parallel_world_size, validate_ulysses_config
-
-
-def _merge_with_image_features(
-    self,
-    inputs_embeds: torch.Tensor,
-    input_ids: torch.Tensor,
-    image_features: torch.Tensor,
-):
-    """
-    Args:
-        inputs_embeds (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length, input_embed_dim)`):
-            The input embeddings.
-        input_ids (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`):
-            The input ids.
-        image_features (:obj:`torch.Tensor` of shape :obj:`(image_token_nums, image_feature_dim)`):
-            The image features to merge with the input embeddings.
-    """
-    image_token_index: int = self.config.media_placeholder_token_id
-
-    batch_size, sequence_length, input_embed_dim = inputs_embeds.shape
-    image_feature_nums, image_feature_dim = image_features.shape
-
-    assert image_feature_dim == input_embed_dim
-
-    image_token_nums = (input_ids == image_token_index).sum()
-    total_image_token_nums = torch.tensor([image_token_nums], dtype=image_token_nums.dtype, device=input_ids.device)
-    total_image_token_nums = gather_outpus_and_unpad(total_image_token_nums, gather_dim=0)  # [sp_size]
-    assert image_feature_nums == total_image_token_nums.sum()
-
-    # (batch_size, sequence_length / sp, input_embed_dim) -> (batch_size * sequence_length / sp, input_embed_dim)
-    inputs_embeds = inputs_embeds.reshape(-1, input_embed_dim)
-
-    # (batch_size, sequence_length / sp) -> (batch_size * sequence_length / sp)
-    input_ids = input_ids.flatten()
-
-    # split image features and fill in the image token positions if there are image tokens
-    sp_image_features = image_features.split(total_image_token_nums.tolist(), dim=0)
-    sp_rank = get_ulysses_sequence_parallel_rank()
-    image_features = sp_image_features[sp_rank]
-    inputs_embeds[input_ids == image_token_index] = image_features
-
-    inputs_embeds = inputs_embeds.reshape((batch_size, sequence_length, input_embed_dim))
-
-    return inputs_embeds
+from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size, validate_ulysses_config
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -140,7 +96,6 @@ def _ulysses_flash_attn_forward(
     else:
         q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
         q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
-    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
     # Flash attention requires the input to have the shape
     # batch_size x seq_length x head_dim x hidden_dim
@@ -151,46 +106,41 @@ def _ulysses_flash_attn_forward(
     kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2)
 
     k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-    kv_seq_len = value_states.shape[-2]
 
-    # patch to get all emb
+    # patch
     ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
-    kv_seq_len *= ulysses_sp_size
+    if ulysses_sp_size > 1:
+        validate_ulysses_config(self.num_heads, ulysses_sp_size)
+
+        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads
+        k_pe = repeat_kv(k_pe, ulysses_sp_size)  # to keep heads=1 after a2a
+        k_nope = repeat_kv(k_nope, num_key_value_groups)
+        value_states = repeat_kv(value_states, num_key_value_groups)
+        q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1)
+        k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1)
+        k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1)
+        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
+        # (batch_size, num_head / sp_size, seq_length, head_size)
+        full_q_len = q.size(2)  # full_q_len = seq_length
+
+    else:
+        full_q_len = q_len
 
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+    cos, sin = self.rotary_emb(value_states, seq_len=full_q_len)
     q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
-    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)
     query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
     query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim)
     key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
     key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
 
     if self.q_head_dim != self.v_head_dim:
         value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
 
-    # patch
-    if ulysses_sp_size > 1:
-        validate_ulysses_config(self.num_heads, ulysses_sp_size)
-
-        num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads
-        key_states = repeat_kv(key_states, num_key_value_groups)
-        value_states = repeat_kv(value_states, num_key_value_groups)
-        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
-        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
-        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
-        # (batch_size, num_head / sp_size, seq_length, head_size)
-        full_q_len = query_states.size(2)  # full_q_len = seq_length
-
-        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
-        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
-        position_ids = torch.concat(position_ids_list, dim=-1)
-
-    else:
-        full_q_len = q_len
-
     # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
     # to be able to avoid many of these transpose/reshape/view.
     query_states = query_states.transpose(1, 2)
diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py
index 0e0ec87bde4..cebff6b2ba8 100644
--- a/verl/models/transformers/monkey_patch.py
+++ b/verl/models/transformers/monkey_patch.py
@@ -214,12 +214,14 @@ def apply_monkey_patch(
     elif model.config.model_type == "kimi_vl":
         if use_remove_padding or ulysses_sp_size > 1:
             # TODO: Changes need to be made when transformers are adapted.
-            from verl.models.transformers.kimi_vl import _merge_with_image_features, _ulysses_flash_attn_forward
+            from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward
 
-            module.KimiVLForConditionalGeneration._merge_with_image_features = _merge_with_image_features
             module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward
             print("Monkey patch FlashAttention2.forward in KimiVL")
 
+        if ulysses_sp_size > 1:
+            patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)
+
         if use_fused_kernels:
             print(f"Not support fused kernels for KimiVL")
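The rewritten `_ulysses_flash_attn_forward` now runs the sequence-parallel all-to-all before RoPE, while queries and keys are still split into nope/rope parts. For readers unfamiliar with the shape bookkeeping behind `gather_seq_scatter_heads(..., seq_dim=2, head_dim=1)`, here is a minimal single-process sketch; `simulate_a2a` is a local stand-in for the real collective (not verl's implementation), and all sizes are arbitrary examples.

```python
# Single-process sketch of the gather-seq / scatter-heads layout change.
# `simulate_a2a` is a hypothetical stand-in for verl's collective, used only
# to show the tensor shapes the patched forward expects.
import torch


def simulate_a2a(shards: list, rank: int, sp_size: int) -> torch.Tensor:
    """Trade all heads of a local sequence shard for a head shard of the full sequence."""
    _, num_heads, _, _ = shards[0].shape
    heads_per_rank = num_heads // sp_size
    head_slice = slice(rank * heads_per_rank, (rank + 1) * heads_per_rank)
    return torch.cat([shard[:, head_slice] for shard in shards], dim=2)


bsz, num_heads, full_seq, head_dim, sp_size = 2, 16, 128, 64, 4
seq_shard = full_seq // sp_size

# Before the a2a, every rank holds all 16 heads of its own 32-token sequence shard.
per_rank_q = [torch.randn(bsz, num_heads, seq_shard, head_dim) for _ in range(sp_size)]

# k_pe has a single RoPE head; repeating it sp_size times (the diff's
# `repeat_kv(k_pe, ulysses_sp_size)`) keeps exactly one head per rank after the scatter.
per_rank_k_pe = [torch.randn(bsz, 1, seq_shard, head_dim).expand(-1, sp_size, -1, -1) for _ in range(sp_size)]

q_rank0 = simulate_a2a(per_rank_q, rank=0, sp_size=sp_size)
k_pe_rank0 = simulate_a2a(per_rank_k_pe, rank=0, sp_size=sp_size)
print(q_rank0.shape)     # torch.Size([2, 4, 128, 64]) -> (bsz, num_heads / sp, full_seq, head_dim)
print(k_pe_rank0.shape)  # torch.Size([2, 1, 128, 64]) -> the RoPE key keeps one head per rank
```

The `self.num_heads // ulysses_sp_size` head count and `full_q_len` sequence length used to allocate `query_states`/`key_states` in the patched forward follow directly from this post-a2a layout.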
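On the monkey-patch side, `patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM)` is only referenced; its body is not part of this diff. As a rough, generic illustration of what sequence-dimension input slicing means (an assumption for exposition, not verl's actual helper), each sequence-parallel rank receives one contiguous chunk of the sequence before the language model runs:

```python
# Generic illustration of sequence-dimension input slicing for Ulysses SP.
# This is NOT verl's patch_vlm_for_ulysses_input_slicing; the helper name and
# behavior here are assumptions for exposition only.
import torch


def slice_inputs_for_sp_rank(input_ids: torch.Tensor, position_ids: torch.Tensor, sp_rank: int, sp_size: int):
    """Give each sequence-parallel rank an equal contiguous slice of the sequence."""
    seq_len = input_ids.size(1)
    assert seq_len % sp_size == 0, "sequence length must be divisible by sp_size"
    shard = seq_len // sp_size
    start, stop = sp_rank * shard, (sp_rank + 1) * shard
    return input_ids[:, start:stop], position_ids[:, start:stop]


input_ids = torch.arange(16).view(1, 16)
position_ids = torch.arange(16).view(1, 16)
ids_rank1, pos_rank1 = slice_inputs_for_sp_rank(input_ids, position_ids, sp_rank=1, sp_size=4)
print(ids_rank1)  # tensor([[4, 5, 6, 7]]) -> rank 1 sees tokens 4..7 of the full sequence
```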