diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 2f11870450..b6b4887407 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -54,8 +54,25 @@ def gaudi_falcon_rotary_embedding_forward(self, query, key, seq_len, position_id """ cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype) + query_expansion_factor = int(query.shape[0] / cos.shape[0]) + if query_expansion_factor > 1: + query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) + query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) + else: + query_cos, query_sin = cos, sin + + key_expansion_factor = int(key.shape[0] / cos.shape[0]) + if key_expansion_factor > 1: + if key_expansion_factor != query_expansion_factor: + key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) + key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0) + else: + key_cos, key_sin = query_cos, query_sin + else: + key_cos, key_sin = cos, sin + if FusedRoPE: - return FusedRoPE.apply(query, cos, sin, 0), FusedRoPE.apply(key, cos, sin, 0) + return FusedRoPE.apply(query, query_cos, query_sin, 0), FusedRoPE.apply(key, key_cos, key_sin, 0) else: return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)