diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index 50a70aa4a04..c8d949eb6f2 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -15,6 +15,14 @@
 from typing import Any, Callable, List, Optional
 
 import torch
+
+try:
+    from ...portable.utils import export_to_edge, save_pte_program
+except ImportError:
+    # Workaround for the different import paths between the executorch pip package and a direct Python invocation
+    # TODO: remove this try/except workaround and have a standard way to import portable.utils
+    # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `examples.portable.utils`.
+    from examples.portable.utils import export_to_edge, save_pte_program
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
     DuplicateDynamicQuantChainPass,
 )
@@ -33,7 +41,6 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.nn.attention import SDPBackend
 
-from ...portable.utils import export_to_edge, save_pte_program
 from ..model_factory import EagerModelFactory
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index 949d127936d..c6b1d844cb6 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -9,8 +9,8 @@
 
 from typing import Optional, Union
 
-import lm_eval
 import torch
+from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
 from executorch.examples.models.llama2.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -20,11 +20,6 @@
 )
 
 from lm_eval.api.model import LM
-from lm_eval.evaluator import evaluate
-from lm_eval.models.huggingface import HFLM as eval_wrapper
-from lm_eval.tasks import get_task_dict
-
-from torch import nn
 
 from .builder import LlamaEdgeManager
 from .export_llama_lib import (
@@ -33,75 +28,6 @@
 )
 
 
-class EagerEvalWrapper(eval_wrapper):
-    """
-    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
-        max_seq_length: Optional[int] = None,
-        use_kv_cache: bool = False,
-    ):
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        super().__init__(device=device)
-        self._model = model
-        self._tokenizer = tokenizer
-        self._device = torch.device(device)
-        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
-        self._use_kv_cache = use_kv_cache
-
-    @property
-    def eot_token_id(self):
-        return self._tokenizer.eos_id
-
-    @property
-    def max_length(self):
-        return self._max_seq_length
-
-    @property
-    def max_gen_toks(self):
-        return 50
-
-    @property
-    def batch_size(self):
-        return 1
-
-    @property
-    def device(self):
-        return self._device
-
-    def tok_encode(self, string: str, **kwargs):
-        tokens = self._tokenizer.encode(string, bos=True, eos=False)
-        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
-        # encoded is a pytorch tensor, but some internal logic in the
-        # eval harness expects it to be a list instead
-        # TODO: verify this for multi-batch as well
-        encoded = encoded.tolist()
-        return encoded
-
-    def tok_decode(self, tokens):
-        decoded = self._tokenizer.decode(tokens)
-        return decoded
-
-    def _model_call(self, inps):
-        if self._use_kv_cache:
-            pos_tensor = torch.arange(
-                self._max_seq_length, dtype=torch.int64, device=self.device
-            )
-
-            # Batch process the whole sequence.
-            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
-            return logits
-        else:
-            return self._model(inps)
-
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
-
-
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -165,40 +91,6 @@ def _model_call(self, inps):
         pass
 
 
-@torch.no_grad()
-def eval(
-    eval_wrapper: LM,
-    tasks: Optional[list] = None,
-    limit: Optional[int] = None,
-) -> dict:
-    """
-    Evaluates a language model on a specified task using the lm-evaluation-harness library.
-
-    Args:
-        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
-        task (str): The name of the evaluation task to perform.
-        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
-
-    Returns:
-        eval_results (dict): A dictionary of evaluation results for the specified task(s).
-    """
-
-    if tasks is None:
-        tasks = ["wikitext"]
-
-    if "hendrycks_test" in tasks:
-        tasks.remove("hendrycks_test")
-        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
-    task_dict = get_task_dict(tasks)
-
-    eval_results = evaluate(
-        eval_wrapper,
-        task_dict,
-        limit=limit,
-    )
-    return eval_results
-
-
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
@@ -307,7 +199,7 @@ def eval_llama(
     eval_wrapper = gen_eval_wrapper(model_name, args)
 
     # Evaluate the model
-    eval_results = eval(
+    eval_results = evaluate_model(
         eval_wrapper,
         args.tasks,
         args.limit,
diff --git a/examples/models/llama2/evaluate/__init__.py b/examples/models/llama2/evaluate/__init__.py
new file mode 100644
index 00000000000..c7f7d2be6e5
--- /dev/null
+++ b/examples/models/llama2/evaluate/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .eager_eval import EagerEvalWrapper, evaluate_model
+
+__all__ = [
+    "evaluate_model",
+    "EagerEvalWrapper",
+]
diff --git a/examples/models/llama2/evaluate/eager_eval.py b/examples/models/llama2/evaluate/eager_eval.py
new file mode 100644
index 00000000000..36db6ff9c53
--- /dev/null
+++ b/examples/models/llama2/evaluate/eager_eval.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional, Union
+
+import lm_eval
+import torch
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer as SentencePieceTokenizer,
+)
+
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate
+from lm_eval.models.huggingface import HFLM as eval_wrapper
+from lm_eval.tasks import get_task_dict
+
+from torch import nn
+
+
+class EagerEvalWrapper(eval_wrapper):
+    """
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device(device)
+        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
+
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def max_gen_toks(self):
+        return 50
+
+    @property
+    def batch_size(self):
+        return 1
+
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str, **kwargs):
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            # Batch process the whole sequence.
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
+@torch.no_grad()
+def evaluate_model(
+    eval_wrapper: LM,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+) -> dict:
+    """
+    Evaluates a language model on the specified tasks using the lm-evaluation-harness library.
+
+    Args:
+        eval_wrapper (LM): An LM wrapper class compatible with lm-evaluation-harness evaluation
+        tasks (Optional[list]): The evaluation tasks to run (defaults to ["wikitext"]).
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+
+    if tasks is None:
+        tasks = ["wikitext"]
+
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
+        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+    task_dict = get_task_dict(tasks)
+
+    eval_results = evaluate(
+        eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    return eval_results
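For reference, a minimal usage sketch of the relocated helpers after this change. It assumes an eager Llama `model` (a `torch.nn.Module`) and a `tokenizer` instance have already been built elsewhere (for example by the export flow in export_llama_lib); both names are placeholders, not part of this diff.

# Sketch: evaluating an eager model with the relocated helpers.
# `model` and `tokenizer` are assumed to exist; they are placeholders here.
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model

eval_wrapper = EagerEvalWrapper(
    model=model,          # eager Llama nn.Module (assumed, built elsewhere)
    tokenizer=tokenizer,  # SentencePiece or Tiktoken tokenizer (assumed)
    max_seq_length=2048,  # falls back to 2048 when None
    use_kv_cache=False,
)

# Runs the "wikitext" task by default; `limit` caps the number of evaluated samples.
eval_results = evaluate_model(eval_wrapper, tasks=["wikitext"], limit=5)
print(eval_results)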