diff --git a/optimum/habana/transformers/models/llava/modeling_llava.py b/optimum/habana/transformers/models/llava/modeling_llava.py index 8119f442c5..36ebb8316a 100644 --- a/optimum/habana/transformers/models/llava/modeling_llava.py +++ b/optimum/habana/transformers/models/llava/modeling_llava.py @@ -151,7 +151,8 @@ def forward( # 1. Extra the input embeddings inputs_embeds = self.get_input_embeddings()(input_ids) - + + image_features = None # 2. Merge text and images if pixel_values is not None and input_ids.shape[1] != 1: image_outputs = self.vision_tower(