diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index b6b4887407..ebb141fa5d 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -55,14 +55,14 @@ def gaudi_falcon_rotary_embedding_forward(self, query, key, seq_len, position_id cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype) query_expansion_factor = int(query.shape[0] / cos.shape[0]) - if query_expansion_factor > 1: + if query_expansion_factor > 1 and cos.shape[0] > 1: query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) else: query_cos, query_sin = cos, sin key_expansion_factor = int(key.shape[0] / cos.shape[0]) - if key_expansion_factor > 1: + if key_expansion_factor > 1 and cos.shape[0] > 1: if key_expansion_factor != query_expansion_factor: key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)