diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index d8c6c7bf1d4..c9faeb556c8 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -41,6 +41,7 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
@@ -48,6 +49,7 @@ def __init__(
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
 
     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded
 
     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)
 
     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
@@ -107,13 +117,22 @@ def __init__(
         from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
         self._et_model = _load_for_executorch(self._model)
+        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]
 
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]
 
 
 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
@@ -139,7 +158,7 @@ def _model_call(self, inps):
 
         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass
 
 
@@ -225,6 +244,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )