Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions optimum/habana/transformers/models/modeling_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,7 @@ def all_reduce(self, input):
dist.inference_all_reduce(input, group=self.mp_group)

def post_all_reduce(self, input):
output = input + self.bias if (self.bias is not None) else input
return output
# inplace addition needed for correct results
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@skavulya very nice catch!

Does this affect any other tests or models (falcon, llama, gemma, qwen2, qwen2_moe)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only looked at llama and qwen2. They were not affected because they didn't use a bias. I can test falcon, and qwen2 moe.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@skavulya if llama and qwen2 did not use a bias, then the output is wrong isn't it? since based on the original code output=input + input

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For llama and qwen2, bias was None so the addition is skipped.

if self.bias is not None:
input += self.bias
return input
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@

class GaudiStarcoder2MLP(Starcoder2MLP):
def pre_mlp_forward(self, x):
inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
output = self.down_proj(inputs)
return output
x = self.c_fc(x)
x = self.act(x)
x = self.c_proj(x)
x = F.dropout(x, p=self.residual_dropout, training=self.training)
return x

def mlp_all_reduce(self, x):
if hasattr(self.down_proj, "all_reduce"):
self.down_proj.all_reduce(x)
if hasattr(self.c_proj, "all_reduce"):
self.c_proj.all_reduce(x)

def post_mlp_forward(self, x):
if hasattr(self.down_proj, "post_all_reduce"):
return self.down_proj.post_all_reduce(x)
if hasattr(self.c_proj, "post_all_reduce"):
return self.c_proj.post_all_reduce(x)
return x


Expand Down Expand Up @@ -431,13 +433,10 @@ def forward(
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
self.self_attn.attention_all_reduce(hidden_states)
hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
self.mlp.mlp_all_reduce(hidden_states)
hidden_states = self.post_mlp(hidden_states, residual)

outputs = (hidden_states,)

Expand Down