diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1d998decc8..c588b63309 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import os import math import warnings from typing import List, Optional, Tuple, Union @@ -319,7 +320,8 @@ def pre_attn_forward( if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None )