From b865c6c1447533f2174069da1b20f91c2360b370 Mon Sep 17 00:00:00 2001
From: Dudi Lester
Date: Thu, 28 Mar 2024 15:16:04 +0200
Subject: [PATCH] enforce recompute flag on fsdpa quantization

---
 optimum/habana/transformers/models/llama/modeling_llama.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 1d998decc8..c588b63309 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1,3 +1,4 @@
+import os
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -319,7 +320,8 @@ def pre_attn_forward(
 
             if q_len == 1:
                 # next token
-                with ht.sdp_kernel(enable_recompute=False):
+                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+                with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = self.fused_scaled_dot_product_attention(
                         query_states, key_states, value_states, attention_mask, 0.0, False, None
                     )