refine if
yuanlehome committed Feb 24, 2025
1 parent d824c2a commit 3102788
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1494,8 +1494,6 @@ def forward(
)
kwargs["max_enc_len_this_time"] = max_enc_len_this_time
kwargs["max_dec_len_this_time"] = max_dec_len_this_time
-self.prefill_phase = max_enc_len_this_time[0] > 0
-self.decode_phase = max_dec_len_this_time[0] > 0

if self.config.append_attn:

@@ -2970,7 +2968,7 @@ def compute_mla_absorb(

out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)

-if self.prefill_phase: # prefill phase
+if kwargs["max_enc_len_this_time"]: # prefill phase
qkv_out_inner = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

from paddlenlp_ops import append_attention

@@ -3027,7 +3025,7 @@ def compute_mla_absorb(

# print(f"prefill {i}: out_linear_out: {out_linear_out}")

-if self.decode_phase: # decode phase
+if kwargs["max_dec_len_this_time"]: # decode phase
if self.config.mla_config.q_lora_rank is not None:
query = paddle.matmul(ln_out, self.q_a_proj_weights[i])
query = self.norm_func(

@@ -3331,7 +3329,7 @@ def compute_mla_absorb(

out_linear_out = paddle.zeros(shape=[ln_out.shape[0], self.embed_dim], dtype=ln_out.dtype)

-if self.prefill_phase: # prefill phase
+if kwargs["max_enc_len_this_time"]: # prefill phase
qkv_out_inner = self.compute_qkv_linear(ln_out, i, latent_cache=latent_cache, **kwargs)

from paddlenlp_ops import append_attention

@@ -3388,7 +3386,7 @@ def compute_mla_absorb(

# print(f"prefill {i}: out_linear_out: {out_linear_out}")

-if self.decode_phase: # decode phase
+if kwargs["max_dec_len_this_time"]: # decode phase
if self.config.mla_config.q_lora_rank is not None:
query = weight_only_linear(
ln_out,
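Summary of the change, for readers skimming the diff: forward() previously cached prefill_phase / decode_phase flags on the layer object, and compute_mla_absorb branched on those attributes; after this commit each branch reads the per-call kwargs directly. The sketch below is a simplified, framework-free illustration of that pattern only. The class names, toy integer inputs, and method bodies are placeholders invented for this example and are not the actual PaddleNLP implementation, which works on Paddle tensors (hence the original max_enc_len_this_time[0] > 0 check).

# Minimal sketch of the refactor: decide prefill/decode from the per-call
# kwargs instead of flags cached on the layer in forward(). All names and
# bodies here are illustrative placeholders, not the real fused-transformer code.


class LayerBefore:
    def forward(self, max_enc_len_this_time, max_dec_len_this_time, **kwargs):
        # Old shape of the code: stash mutable phase flags on self.
        self.prefill_phase = max_enc_len_this_time > 0
        self.decode_phase = max_dec_len_this_time > 0
        kwargs["max_enc_len_this_time"] = max_enc_len_this_time
        kwargs["max_dec_len_this_time"] = max_dec_len_this_time
        return self.compute_mla_absorb(**kwargs)

    def compute_mla_absorb(self, **kwargs):
        phases = []
        if self.prefill_phase:  # reads state written by an earlier forward()
            phases.append("prefill")
        if self.decode_phase:
            phases.append("decode")
        return phases


class LayerAfter:
    def forward(self, max_enc_len_this_time, max_dec_len_this_time, **kwargs):
        # New shape of the code: no layer-level flags, only per-call kwargs.
        kwargs["max_enc_len_this_time"] = max_enc_len_this_time
        kwargs["max_dec_len_this_time"] = max_dec_len_this_time
        return self.compute_mla_absorb(**kwargs)

    def compute_mla_absorb(self, **kwargs):
        phases = []
        if kwargs["max_enc_len_this_time"]:  # phase decided from this call's args
            phases.append("prefill")
        if kwargs["max_dec_len_this_time"]:
            phases.append("decode")
        return phases


if __name__ == "__main__":
    # Both variants take the same branches for a prefill-only toy batch ...
    print(LayerBefore().forward(max_enc_len_this_time=8, max_dec_len_this_time=0))
    print(LayerAfter().forward(max_enc_len_this_time=8, max_dec_len_this_time=0))
    # ... but only LayerAfter is free of flags left over from a previous call.

The practical effect is that the branch taken inside compute_mla_absorb depends only on the arguments of the current call, not on attributes set by a different method.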
