 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.integrations import use_kernel_forward_from_hub
-from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
 from transformers.modeling_outputs import (
@@ -146,7 +149,9 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -187,9 +192,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
         self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = Qwen2MLP(config)
-        self.input_layernorm = Qwen2RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
+        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
@@ -381,9 +384,9 @@ def forward(
         }
         # The sliding window alternating layers are not always activated depending on the config
         if self.has_sliding_layers:
-            causal_mask_mapping[
-                "sliding_attention"
-            ] = create_sliding_window_causal_mask(**mask_kwargs)
+            causal_mask_mapping["sliding_attention"] = (
+                create_sliding_window_causal_mask(**mask_kwargs)
+            )
 
         hidden_states = inputs_embeds
 
@@ -843,4 +846,4 @@ def forward(
843846 "Qwen2ForSequenceClassification" ,
844847 "Qwen2ForTokenClassification" ,
845848 "Qwen2ForQuestionAnswering" ,
846- ]
849+ ]