diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index f5c79fed14..03ce7a4984 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -40,7 +40,7 @@ apply_rotary_pos_emb, ) from transformers.utils import logging - +from optimum.habana.transformers.models.modeling_all_models import KVCache from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) @@ -55,19 +55,6 @@ logger = logging.get_logger(__name__) -def update(prev, cur, dim, idx): - orig_cur = cur - if prev.shape == cur.shape: - # Initialize - prev.copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - return prev.index_copy_(dim, idx - 1, cur) - else: - return torch.cat((prev, cur), dim=dim) - - # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, @@ -138,34 +125,33 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config) - self.past_key = None - self.past_value = None + self.k_cache = KVCache() + self.v_cache = KVCache() + # TODO: replace these two + #self.past_key = None + #self.past_value = None self.layer_idx = layer_idx - def allocate_kv_cache(self, batch_size, seq_len): - key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - value_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != key_shape: - # if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - self.past_key = torch.empty(key_shape, dtype=dtype, device=device) - self.past_value = torch.empty(value_shape, dtype=dtype, device=device) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: - # if not hasattr(self, 'past_key'): + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def forward( self, @@ -222,27 +208,28 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: - if reuse_cache: - past_key = self.past_key - past_value = self.past_value - else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx) - value_states = update(past_value, value_states, 2, token_idx) if use_cache: + # reuse k, v, self_attention if reuse_cache: - past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: past_key_value = None - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( @@ -306,8 +293,8 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, seq_len): - self.self_attn.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -376,9 +363,9 @@ def forward( class GaudiMistralModel(MistralModel): - def allocate_kv_cache(self, batch_size, seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -544,8 +531,8 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _): - self.model.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index 98b3d53a1f..88c3fa81b2 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -175,3 +175,44 @@ def all_reduce(self, input): def post_all_reduce(self, input): output = input + self.bias if (self.bias is not None) else input return output + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len)