From 23d4adad2c9aae947ab7d09ba1fdacee2b3d0ff5 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Fri, 5 Apr 2024 15:52:17 -0700 Subject: [PATCH 1/4] create kvcache object in mistral --- .../models/mistral/modeling_mistral.py | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index f5c79fed14..b3e33a1eae 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -54,19 +54,45 @@ logger = logging.get_logger(__name__) +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) -def update(prev, cur, dim, idx): - orig_cur = cur - if prev.shape == cur.shape: - # Initialize - prev.copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - return prev.index_copy_(dim, idx - 1, cur) - else: - return torch.cat((prev, cur), dim=dim) + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( From ae4b9fe0f91cdbecd702b84c85f1defd9449ec0e Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Fri, 5 Apr 2024 23:30:01 +0000 Subject: [PATCH 2/4] replace with new kv cache objects --- .../models/mistral/modeling_mistral.py | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index b3e33a1eae..e6e2f69efd 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -164,34 +164,33 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config) - self.past_key = None - self.past_value = None + self.k_cache = KVCache() + self.v_cache = KVCache() + # TODO: replace these two + #self.past_key = None + #self.past_value = None self.layer_idx = layer_idx def allocate_kv_cache(self, batch_size, seq_len): - key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - value_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != key_shape: - # if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - self.past_key = torch.empty(key_shape, dtype=dtype, device=device) - self.past_value = torch.empty(value_shape, dtype=dtype, device=device) + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: - # if not hasattr(self, 'past_key'): + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def forward( self, @@ -248,27 +247,28 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: - if reuse_cache: - past_key = self.past_key - past_value = self.past_value - else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx) - value_states = update(past_value, value_states, 2, token_idx) if use_cache: + # reuse k, v, self_attention if reuse_cache: - past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: past_key_value = None - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( From 0a325e9bc9289f16f103bfeb4b8bc8a8456095a7 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 8 Apr 2024 18:29:22 +0000 Subject: [PATCH 3/4] include max/inq sequence length --- .../models/mistral/modeling_mistral.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index e6e2f69efd..386baed5cd 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -171,7 +171,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): #self.past_value = None self.layer_idx = layer_idx - def allocate_kv_cache(self, batch_size, seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.k_proj.weight.device dtype = self.config.torch_dtype @@ -332,8 +332,8 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, seq_len): - self.self_attn.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -402,9 +402,9 @@ def forward( class GaudiMistralModel(MistralModel): - def allocate_kv_cache(self, batch_size, seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -570,8 +570,8 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _): - self.model.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) From d4666c4083868a1c194f8150c8ab289787407d3a Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 8 Apr 2024 18:57:44 +0000 Subject: [PATCH 4/4] move kv cache to modeling_all_models --- .../models/mistral/modeling_mistral.py | 41 +------------------ .../models/modeling_all_models.py | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 386baed5cd..03ce7a4984 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -40,7 +40,7 @@ apply_rotary_pos_emb, ) from transformers.utils import logging - +from optimum.habana.transformers.models.modeling_all_models import KVCache from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) @@ -54,45 +54,6 @@ logger = logging.get_logger(__name__) -class KVCache(torch.nn.Module): - def __init__(self): - super(KVCache, self).__init__() - self.cache = None - self.inp_seq_len = -1 - - def allocate(self, inp_seq_len, dtype, device, shape): - if self.cache is None or self.cache.shape != shape: - self.inp_seq_len = inp_seq_len - self.cache = torch.zeros(shape, dtype=dtype, device=device) - else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - self.cache.fill_(0) - - def update(self, prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) - - def get_shape(self): - if self.cache is None: - return None - return self.cache.shape - - def forward(self, cur, dim, idx): - return self.update(self.cache, cur, dim, idx, self.inp_seq_len) # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index 98b3d53a1f..88c3fa81b2 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -175,3 +175,44 @@ def all_reduce(self, input): def post_all_reduce(self, input): output = input + self.bias if (self.bias is not None) else input return output + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len)