Commit

remove some comments
nemonameless committed Dec 18, 2024
1 parent 754c0d6 commit a71f200
Showing 2 changed files with 12 additions and 97 deletions.
105 changes: 11 additions & 94 deletions paddlemix/models/mPLUGOwl3/modeling_hyper_qwen2.py
@@ -309,8 +309,6 @@ def __init__(self, config: HyperQwen2Config, layer_idx: Optional[int] = None, is
def apply_mi_rope(self, key_layer, image_pos, length_each_img):
# input shape should be [s b h d]
key_layer = rearrange(key_layer, "b h s d -> s b h d")
# if self.rotary_emb_core.inv_freq.device!=key_layer.device:
# self.rotary_emb_core.inv_freq = self.rotary_emb_core.inv_freq.to(key_layer.device)
rotary_pos_emb_max_seq_len = self.config.max_position_embeddings
ntk_alpha = 1
rotary_pos_emb = self.rotary_emb_core(rotary_pos_emb_max_seq_len, ntk_alpha=ntk_alpha)
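# MI-RoPE (media-interleaved rotary embedding): the rotary table is built over the full
# max_position_embeddings range; the image keys are then rotated with the embedding taken
# at each image's insertion position in the text sequence (image_pos), so visual keys
# share the positional space of the text queries.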
@@ -369,41 +367,26 @@ def hyperattention(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# print('query_states, key_states', query_states.sum().item(), key_states.sum().item())
# 29952.0 492.0
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# print('query_states, key_states', query_states.sum().item(), key_states.sum().item())
# 18304.0 -776.0
# print('query_states, key_states', query_states.shape, key_states.shape)
# [1, 28, 1, 128] [1, 4, 1, 128]

if past_key_value is not None:
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
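# legacy tuple-style KV cache: append the new key/value states along axis=2
# (the sequence axis of [batch, kv_heads, seq, head_dim]) and re-pack the tuple
# so the next decoding step can reuse it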
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
past_key_value = (key_states, value_states) if use_cache else None
# print('query_states key_states, value_states', query_states.sum().item(), key_states.sum().item(), value_states.sum().item())
# print('query_states key_states, value_states', query_states.shape, key_states.shape, value_states.shape)
# q k v [1, 28, 74, 128] [1, 4, 74, 128] [1, 4, 74, 128]
# q k v [1, 28, 1, 128] [1, 4, 75, 128] [1, 4, 75, 128]

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
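# grouped-query attention: repeat_kv tiles each of the num_key_value_heads KV heads
# num_key_value_groups times so keys/values line up with the query heads,
# e.g. [bsz, 4, kv_len, 128] -> [bsz, 28, kv_len, 128] for 28 query heads and 4 KV heads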
# -5440. -1712.

# add visual to kv
length_each_img = image_embeds.shape[1]
# [7, 729, 3584] sum 78336. mean 0.00430298
try:
image_embeds = self.v_kv_proj(image_embeds)
except:
image_embeds = self.v_kv_proj(image_embeds.astype("bfloat16"))
# [7, 729, 1024] sum 184320.
image_start = 0
context_layer = []
for bi, media_starts in enumerate(media_offset):
Expand Down Expand Up @@ -432,8 +415,6 @@ def hyperattention(
H=self.num_key_value_heads,
) # b h s d
image_start += num_images
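# image_start advances past this sample's images so that the next batch element
# reads only its own slice of image_embeds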
# print("curr_query_layer", bi, curr_visual_key_layer.sum().item(), curr_visual_value_layer.sum().item())
# [1, 4, 5103, 128] 206848. -22400.0

curr_visual_key_layer = self.apply_mi_rope(
curr_visual_key_layer, media_starts, length_each_img=length_each_img
@@ -459,14 +440,7 @@ def hyperattention(
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
full_mask = causal_mask

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# # Reference: https://github.com/pytorch/pytorch/issues/112577.
# if curr_query_layer.device.type == "cuda" and full_mask is not None:
# curr_query_layer = curr_query_layer.contiguous()
# curr_key_layer = curr_key_layer.contiguous()
# curr_value_layer = curr_value_layer.contiguous()

# full_mask.shape [1, 1, 74, 5177] # sum 196689
# Note: paddle's scaled_dot_product_attention takes q/k/v as [batch, seq_len, num_heads, head_dim], unlike torch's [batch, num_heads, seq_len, head_dim], hence the transposes below
attn_output = paddle.nn.functional.scaled_dot_product_attention(
curr_query_layer.transpose(
[0, 2, 1, 3]
@@ -481,7 +455,6 @@ def forward(
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=is_causal,
# enable_gqa=True, # gqa cannot be used here because the mask path requires XFORMERS, which does not support gqa
) # -> (N, ..., L, Ev)
# torch attn_output.shape [1, 28, 72, 128]
attn_output = attn_output.transpose([0, 2, 1, 3])
@@ -490,7 +463,6 @@ def forward(
attn_output = context_layer = paddle.concat(context_layer, axis=0)

attn_output = attn_output.transpose([0, 2, 1, 3])
# print('attn_output', attn_output.shape) # [1, 74, 28, 128] [1, 1, 28, 128]
attn_output = attn_output.reshape([bsz, q_len, self.hidden_size])

attn_output = self.o_proj(attn_output)
@@ -526,7 +498,6 @@ def forward(
# )

if self.is_hyper_enabled and image_embeds is not None:
# this branch is always taken here
return self.hyperattention(
hidden_states,
attention_mask,
@@ -558,34 +529,26 @@ def forward(

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
past_key_value = (key_states, value_states) if use_cache else None

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if attention_mask is not None: # (1,1,1,60)
if attention_mask is not None:
if tuple(attention_mask.shape) != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {bsz, 1, q_len, kv_seq_len}, but is {tuple(attention_mask.shape)}"
)
# # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# # Reference: https://github.com/pytorch/pytorch/issues/112577.
# if query_states.device.type == "cuda" and attention_mask is not None:
# query_states = query_states.contiguous()
# key_states = key_states.contiguous()
# value_states = value_states.contiguous()

# Note: the q/k/v layout of paddle's scaled_dot_product_attention differs from torch's ([batch, seq_len, num_heads, head_dim] vs [batch, num_heads, seq_len, head_dim])
attn_output = paddle.nn.functional.scaled_dot_product_attention(
query_states.transpose([0, 2, 1, 3]), # [1, 28, 74, 128] sum 21632.
key_states.transpose([0, 2, 1, 3]), # [1, 28, 74, 128] sum 335872.
@@ -604,6 +567,7 @@ def forward(


# Original Attention of Qwen2
# PaddleNLP only has Qwen2Attention
QWEN2_ATTENTION_CLASSES = {
"eager": Qwen2Attention,
"flash_attention_2": Qwen2Attention, # Qwen2FlashAttention2,
@@ -616,13 +580,8 @@ def __init__(self, config: HyperQwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size

if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
self.is_hyper_enabled = (layer_idx + 1) in config.hyper_layers
# print('layer_idx', layer_idx, self.is_hyper_enabled)
# TODO: using Qwen2Attention gives wrong answers, while using HyperQwen2SdpaAttention for all layers gives correct answers, but this still needs a check
if 1: # self.is_hyper_enabled:
self.self_attn = HyperQwen2SdpaAttention(config, layer_idx, is_hyper_enabled=self.is_hyper_enabled)
else:
@@ -646,32 +605,29 @@ def forward(
) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`paddle.Tensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Shared LayerNorm
if image_embeds is not None and self.is_hyper_enabled:
# 134144
image_embeds = self.input_layernorm(image_embeds)
# 78336.
media_kwargs = {"image_embeds": image_embeds, "media_offset": media_offset}
else:
image_embeds = media_offset = None
media_kwargs = {}
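# media_kwargs is only populated for hyper-enabled layers that actually receive image
# features; otherwise the layer runs plain text-only self-attention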

# Self Attention
# hidden_states.sum 76.50000000
hidden_states, self_attn_weights, present_key_value = self.self_attn( # -704. 2080. (48128., 240.)
hidden_states=hidden_states.cast(paddle.bfloat16), # [1, 74, 3584] sum -704.
attention_mask=attention_mask,
Expand All @@ -682,7 +638,6 @@ def forward(
**media_kwargs, # {}
)
hidden_states = residual + hidden_states
# -1.71093750 + -704.

# Fully Connected
residual = hidden_states
@@ -757,34 +712,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if len(attention_mask.shape) == 2:
expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
# For decoding phase in generation, seq_length = 1, we don't need to add causal mask
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
past_key_values_length=past_key_values_length,
)
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
# if attention_mask is already 4-D, do nothing
else:
expanded_attn_mask = attention_mask
else:
expanded_attn_mask = _make_causal_mask(
input_shape,
past_key_values_length=past_key_values_length,
)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask

def forward(
self,
input_ids: paddle.Tensor = None,
@@ -819,12 +746,6 @@ def forward(

past_key_values_length = 0

# if use_cache:
# use_legacy_cache = False #not isinstance(past_key_values, Cache)
# #if use_legacy_cache:
# # past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# past_key_values_length = past_key_values.get_usable_length(seq_length)

if past_key_values is None:
past_key_values = tuple([None] * len(self.layers))
# NOTE: to make cache can be clear in-time
@@ -836,7 +757,6 @@ def forward(
cache_length = past_key_values[0][0].shape[1] #
past_key_values_length += cache_length

# print('position_ids before', position_ids)
if position_ids is None:
position_ids = paddle.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64
@@ -845,14 +765,10 @@ def forward(
else:
position_ids = position_ids.reshape([-1, seq_length]).astype(dtype="int64")

# print('position_ids', position_ids)
# print('seq_length', seq_length)
# print('past_key_values_length', past_key_values_length)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

attention_mask = None
attention_mask = None #

hidden_states = inputs_embeds

@@ -867,7 +783,7 @@ def forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () # not none
next_decoder_cache = () # not None

for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
Expand Down Expand Up @@ -1040,6 +956,7 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
# The following follows the PaddleNLP Qwen2ForCausalLM implementation and differs from the torch mPLUG-Owl3
batch_size, seq_length = input_ids.shape
position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length)))
attention_mask = kwargs.get("attention_mask", None)
4 changes: 1 addition & 3 deletions paddlemix/models/mPLUGOwl3/modeling_mplugowl3.py
@@ -81,8 +81,7 @@ def _small_batched_forward(self, pixel_values):
end_idx = min(B, i + vision_batch_size)
tmp_hs = self.vision_model(pixel_values[start_idx:end_idx], output_hidden_states=True).hidden_states[-2]
image_forward_out.append(tmp_hs)
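# hidden_states[-2]: the penultimate layer of the vision tower is taken as the visual
# feature; batching in chunks of vision_batch_size keeps peak activation memory bounded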
# image_forward_out[0].sum()
# [7, 729, 1152] sum -872448.

vision_embedding = paddle.concat(image_forward_out, axis=0)
assert vision_embedding.shape[0] == B
return vision_embedding
@@ -95,7 +94,6 @@ def forward_image(self, pixel_values):

if self.vision2text_model is not None:
image_embeds = self.vision2text_model(image_embeds)
# [7, 729, 3584] sum 134144. mean 0.00735474
else:
pass

