diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 60b7899e72..f4c9d454ab 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -578,7 +578,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.dense, "all_reduce"): - self.dense.post_all_reduce(attn_output) + return self.dense.post_all_reduce(attn_output) return attn_output @@ -598,7 +598,7 @@ def mlp_all_reduce(self, x): def post_mlp_forward(self, x): if hasattr(self.dense_4h_to_h, "all_reduce"): - self.dense_4h_to_h.post_all_reduce(x) + return self.dense_4h_to_h.post_all_reduce(x) return x diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py index 1c270b62f6..5ce1330911 100755 --- a/optimum/habana/transformers/models/gemma/modeling_gemma.py +++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py @@ -357,7 +357,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.o_proj, "post_all_reduce"): - self.o_proj.post_all_reduce(attn_output) + return self.o_proj.post_all_reduce(attn_output) return attn_output diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ce7d3cc283..9b3bdd6388 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -760,7 +760,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.o_proj, "post_all_reduce"): - self.o_proj.post_all_reduce(attn_output) + return self.o_proj.post_all_reduce(attn_output) return attn_output diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index a769220242..90aa2d5e0f 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -164,7 +164,5 @@ def all_reduce(self, input): dist.inference_all_reduce(input, group=self.mp_group) def post_all_reduce(self, input): - # inplace addition needed for correct results - if self.bias is not None: - input += self.bias - return input + output = input + self.bias if (self.bias is not None) else input + return output diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py index 7bd8ebcd9b..9c779799c5 100644 --- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py +++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py @@ -419,7 +419,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.o_proj, "post_all_reduce"): - self.o_proj.post_all_reduce(attn_output) + return self.o_proj.post_all_reduce(attn_output) return attn_output diff --git a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py index fa1d8aae53..5b9da828cd 100755 --- a/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -491,7 +491,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.o_proj, "post_all_reduce"): - self.o_proj.post_all_reduce(attn_output) + return self.o_proj.post_all_reduce(attn_output) return attn_output diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py index b01a176368..00d9de7193 100644 --- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py +++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py @@ -17,6 +17,7 @@ ############################################################################### import math +import os from typing import List, Optional, Tuple, Union import torch @@ -307,7 +308,8 @@ def pre_attn_forward( if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) @@ -374,7 +376,7 @@ def attention_all_reduce(self, attn_output): def post_attn_forward(self, attn_output): if hasattr(self.o_proj, "post_all_reduce"): - self.o_proj.post_all_reduce(attn_output) + return self.o_proj.post_all_reduce(attn_output) return attn_output