diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index c9650531705..8bb653739b3 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -6,16 +6,22 @@
 
 
 import argparse
-from typing import Optional
+
+from typing import Optional, Union
 
 import lm_eval
 import torch
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer as SentencePieceTokenizer,
+)
+
 from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
 from lm_eval.tasks import get_task_dict
-from sentencepiece import SentencePieceProcessor
+
 
 from torch import nn
 
 from .builder import LlamaEdgeManager
@@ -33,7 +39,7 @@ class GPTFastEvalWrapper(eval_wrapper):
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
         super().__init__()
@@ -46,7 +52,7 @@ def __init__(
 
     @property
     def eot_token_id(self):
-        return self._tokenizer.eos_id()
+        return self._tokenizer.eos_id
 
     @property
     def max_length(self):
@@ -65,7 +71,7 @@ def device(self):
         return self._device
 
     def tok_encode(self, string: str, **kwargs):
-        tokens = [self._tokenizer.bos_id()] + self._tokenizer.encode(string)
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
         encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
         # encoded is a pytorch tensor, but some internal logic in the
         # eval harness expects it to be a list instead
@@ -93,7 +99,7 @@ class ETEagerEvalWrapper(GPTFastEvalWrapper):
     def __init__(
         self,
         model: str,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
         super().__init__(None, tokenizer, max_seq_length)
@@ -120,7 +126,7 @@ class ETRunnerEvalWrapper(GPTFastEvalWrapper):
     def __init__(
         self,
         model: str,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         tokenizer_bin: str,
         max_seq_length: Optional[int] = None,
     ):
@@ -183,7 +189,11 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
+    try:
+        tokenizer = SentencePieceTokenizer(model_path=str(args.tokenizer_path))
+    except Exception:
+        print("Using Tiktokenizer")
+        tokenizer = Tiktoken(model_path=str(args.tokenizer_path))
 
     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:
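
Reviewer note: both tokenizer wrappers expose the same surface, encode(text, bos=..., eos=...) plus eos_id as a plain attribute (hence the eos_id() -> eos_id change above), which is what lets the eval wrappers accept Union[SentencePieceTokenizer, Tiktoken]. Below is a minimal sketch of the fallback pattern the last hunk introduces; the load_tokenizer helper and the literal path are illustrative only, everything else mirrors the diff:

# Sketch only: mirrors the try/except in gen_eval_wrapper above.
# SentencePieceTokenizer raises on a non-SentencePiece model file,
# in which case we fall back to the tiktoken-based wrapper.
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.examples.models.llama2.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)


def load_tokenizer(tokenizer_path: str):
    # Hypothetical helper, not part of the PR itself.
    try:
        return SentencePieceTokenizer(model_path=tokenizer_path)
    except Exception:
        print("Using Tiktokenizer")
        return Tiktoken(model_path=tokenizer_path)


tokenizer = load_tokenizer("/path/to/tokenizer.model")  # illustrative path
# The tokenizer now prepends BOS itself, replacing the old
# [bos_id()] + encode(...) construction in tok_encode.
tokens = tokenizer.encode("hello world", bos=True, eos=False)
print(tokens, tokenizer.eos_id)  # attribute access, matching eot_token_id above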