Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.dense, "all_reduce"):
- self.dense.post_all_reduce(attn_output)
+ return self.dense.post_all_reduce(attn_output)
return attn_output


Expand All @@ -598,7 +598,7 @@ def mlp_all_reduce(self, x):

def post_mlp_forward(self, x):
if hasattr(self.dense_4h_to_h, "all_reduce"):
- self.dense_4h_to_h.post_all_reduce(x)
+ return self.dense_4h_to_h.post_all_reduce(x)
return x


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output


Expand Down
6 changes: 2 additions & 4 deletions optimum/habana/transformers/models/modeling_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,5 @@ def all_reduce(self, input):
dist.inference_all_reduce(input, group=self.mp_group)

def post_all_reduce(self, input):
- # inplace addition needed for correct results
- if self.bias is not None:
- input += self.bias
- return input
+ output = input + self.bias if (self.bias is not None) else input
+ return output
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
###############################################################################

import math
+ import os
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -307,7 +308,8 @@ def pre_attn_forward(

if q_len == 1:
# next token
- with ht.sdp_kernel(enable_recompute=False):
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+ use_recompute = bool(os.getenv("QUANT_CONFIG"))

Personal suggestion: `bool(os.getenv("QUANT_CONFIG"))` gives the same result with less code.

+ with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
Expand Down Expand Up @@ -374,7 +376,7 @@ def attention_all_reduce(self, attn_output):

def post_attn_forward(self, attn_output):
if hasattr(self.o_proj, "post_all_reduce"):
- self.o_proj.post_all_reduce(attn_output)
+ return self.o_proj.post_all_reduce(attn_output)
return attn_output


Expand Down