Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 100 additions & 37 deletions examples/text-generation/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,23 @@
###############################################################################

import argparse
from typing import Literal, Optional
import logging
from typing import Literal, Optional, Union

import torch
import torch.nn.functional as F
from lm_eval.api.instance import Instance
from lm_eval.models.huggingface import HFLM, TemplateLM
from lm_eval.models.utils import get_dtype, stop_sequences_criteria

# Local imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


logger = logging.getLogger(__name__)


class HabanaModelAdapter(HFLM):
def __init__(
self,
Expand All @@ -35,10 +43,18 @@ def __init__(
args: argparse.Namespace,
options: GenerationConfig,
backend: Literal["default", "causal", "seq2seq"] = "default",
truncation: Optional[bool] = False,
logits_cache: bool = True,
max_length: Optional[int] = None,
softmax_dtype: Union[str, torch.dtype, None] = None,
add_bos_token: Optional[bool] = True,
prefix_token_id: Optional[int] = None,
delta: Optional[str] = None,
# end token for thinking, either the string or int token id.
# splits to get response after this token (if provided).
think_end_token: Optional[Union[str, int]] = None,
enable_thinking: Optional[bool] = None,
chat_template_args: Optional[dict] = None,
**kwargs,
) -> None:
# To skip cuda code of the HFLM init
Expand All @@ -54,11 +70,32 @@ def __init__(
self.peft = args.peft_model
self.delta = delta
self.custom_prefix_token_id = prefix_token_id
if isinstance(think_end_token, str) and think_end_token.isdigit():
self.think_end_token = int(think_end_token)
else:
self.think_end_token = think_end_token

self.chat_template_args = chat_template_args or {}
if enable_thinking is not None:
self.chat_template_args.update({"enable_thinking": enable_thinking})

# determine which of 'causal' and 'seq2seq' backends to use for HF models
self._get_backend(config=self._config, backend=backend, trust_remote_code=args.trust_remote_code)
self.truncation = truncation
self.logits_cache = logits_cache
self.add_bos_token = add_bos_token
self._max_length = options.max_length
self._max_length = max_length
self.softmax_dtype = get_dtype(softmax_dtype) if softmax_dtype is not None else None
self.hpu_graphs = args.use_hpu_graphs
self.use_lazy_mode = True
if args.torch_compile:
self.use_lazy_mode = False
self.vocab_size = self._model.config.vocab_size
if "gemma" in getattr(self._config, "model_type", ""):
self.add_bos_token = True
logger.info(
f"Model type is '{self._config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
)
self.batch_size_per_gpu = int(args.batch_size)
self.revision = args.model_revision
self.model_inputs = {"use_cache": self.options.use_cache}
Expand Down Expand Up @@ -119,16 +156,27 @@ def eot_token_id(self) -> int:

@property
def max_length(self) -> int:
return self.buckets[-1]
# Legacy
return self._max_length if self._max_length else self.buckets[-1]

@property
def device(self):
    """Device on which lm_eval code should keep its tensors.

    Padding is done by this adapter itself (otherwise recompilations would
    occur), so 'cpu' is returned to keep tensors on CPU in lm_eval code.
    """
    host_device = "cpu"
    return host_device

def find_bucket(self, length: int) -> list[int]:
return [b for b in self.buckets if b >= length][0]
@max_length.setter
def max_length(self, value: int) -> None:
    """Store ``value`` in ``self._max_length`` as the maximum length."""
    self._max_length = value

def find_bucket(self, length: int, key=lambda b, length: b >= length) -> int:
    """Return the bucket to use for a sequence of ``length`` tokens.

    Args:
        length: Sequence length a bucket is needed for.
        key: Predicate called as ``key(bucket, length)``; the first bucket
            in ``self.buckets`` order for which it is truthy is selected.
            By default a bucket matches when it is at least ``length``.

    Returns:
        The first matching bucket. When no bucket matches, ``length`` itself
        is appended to ``self.buckets``, the list is re-sorted, and
        ``length`` is returned.
    """
    # A private sentinel distinguishes "no match" from any real bucket value.
    not_found = object()
    chosen = next((bucket for bucket in self.buckets if key(bucket, length)), not_found)
    if chosen is not not_found:
        return chosen

    # Nothing fits: register the requested length as a new bucket.
    self.buckets.append(length)
    self.buckets.sort()
    return length

def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
bs, seq_length = inps.shape
Expand All @@ -146,38 +194,53 @@ def _model_call(self, inps: torch.Tensor) -> torch.Tensor:
logits = logits.to(torch.float32)
return logits

def get_model_info(self) -> dict:
def generate_until(self, requests: list[Instance], disable_tqdm: bool = False) -> list[str]:
"""
Override to change only max_length property
"""
Patched method to get Hugging Face model information for experiment reproducibility.
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L1375
Remove from SynapseAI 1.21
legacy_max_length = self.max_length
self.max_length = super().max_length
# Call the parent class's implementation for the unchanged parts
res = super().generate_until(requests, disable_tqdm)
self.max_length = legacy_max_length
return res

def _model_generate(self, context, max_length, stop, **generation_kwargs):
"""
Patched method
source: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/models/huggingface.py/#L858
"""

def get_model_num_params(model) -> int:
if hasattr(model, "num_parameters"):
return model.num_parameters()
if hasattr(model, "parameters"):
return sum(p.numel() for p in model.parameters())
else:
return -1

def get_model_dtype(model) -> str:
if hasattr(model, "dtype"):
return model.dtype
else:
return ""

def get_model_sha(pretrained: str, revision: str) -> str:
return ""

model_info = {
"model_num_parameters": get_model_num_params(self._model),
"model_dtype": get_model_dtype(self._model),
"model_revision": self.revision,
"model_sha": get_model_sha(self.pretrained, self.revision),
}
if self.peft:
model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
if self.delta:
model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
return model_info
# temperature = 0.0 if not set
# if do_sample is false and temp==0.0:
# remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)

# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False

if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
# build stopping criteria
stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, context.shape[1], context.shape[0])
# to avoid graph recompilation
if self.options.static_shapes:
self.options.bucket_internal = True
_ = self.find_bucket(context.shape[1])
max_gen_toks = max_length - context.shape[1]
# move context & attention_mask to hpu
context = context.to("hpu")
generation_kwargs["attention_mask"] = generation_kwargs["attention_mask"].to("hpu")
return self.model.generate(
input_ids=context,
max_new_tokens=max_gen_toks,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
hpu_graphs=self.hpu_graphs,
lazy_mode=self.use_lazy_mode,
**generation_kwargs,
)
6 changes: 4 additions & 2 deletions examples/text-generation/requirements_lm_eval.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
lm-eval==0.4.7
lm-eval==0.4.9.1
datasets==3.6.0
langdetect<=1.0.9
immutabledict<=4.2.1
tiktoken
blobfile
sentencepiece
sentencepiece
114 changes: 107 additions & 7 deletions examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import logging
import multiprocessing as mp
import os
from pathlib import Path
from typing import Union

import psutil

Expand Down Expand Up @@ -53,6 +55,20 @@ def LimitedSpawnPool(_):
mp.Pool = LimitedSpawnPool


def try_parse_json(value: str) -> Union[str, dict, None]:
    """
    Parse a command-line value as JSON, falling back to the raw string.

    From https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/__main__.py

    Args:
        value: Raw argument text, or ``None`` when the option was not given.

    Returns:
        ``None`` when ``value`` is ``None``; the decoded JSON when ``value``
        is valid JSON; otherwise ``value`` unchanged.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not valid JSON but
            contains a ``{``.
    """
    if value is not None:
        try:
            decoded = json.loads(value)
        except json.JSONDecodeError:
            # A brace suggests the user meant JSON, so report the error
            # instead of passing the malformed text through as a string.
            if "{" not in value:
                return value
            raise argparse.ArgumentTypeError(f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings.")
        else:
            return decoded
    return None


def setup_lm_eval_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Evaluation script for HPU"
Expand All @@ -75,7 +91,14 @@ def setup_lm_eval_parser():
help="Tasks to run",
default=["hellaswag", "lambada_openai", "piqa", "winogrande"],
)
parser.add_argument("--limit_iters", type=int, help="limit examples to run that many iterations", default=None)
parser.add_argument(
"--limit",
"-L",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument(
"--show_config",
action="store_true",
Expand All @@ -94,34 +117,111 @@ def setup_lm_eval_parser():
help="System instruction to be used in the prompt",
)
parser.add_argument("--max_graphs", type=int, help="Maximum number of HPU graphs", default=None)
parser.add_argument(
"--gen_kwargs",
type=try_parse_json,
default=None,
help=(
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
),
)
parser.add_argument(
"--num_fewshot",
"-f",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
parser.add_argument(
"--fewshot_as_multiturn",
action="store_true",
default=False,
help="If True, uses the fewshot as a multi-turn conversation",
)
parser.add_argument(
"--metadata",
type=json.loads,
default=None,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
parser.add_argument(
"--apply_chat_template",
type=str,
nargs="?",
const=True,
default=False,
help=(
"If True, apply chat template to the prompt. "
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
"To apply a specific template from the available list of templates, provide the template name as an argument. "
"E.g. `--apply_chat_template template_name`"
),
)
parser.add_argument(
"--samples",
"-E",
default=None,
type=str,
metavar="/path/to/json",
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
args = setup_parser(parser)

return args


def main() -> None:
# Modified based on cli_evaluate function in https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.7/lm_eval/__main__.py/#L268
# Modified based on cli_evaluate function in https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.9.1/lm_eval/__main__.py#L301
args = setup_lm_eval_parser()
model, _, tokenizer, generation_config = initialize_model(args, logger)

import torch
from lm_eval import evaluator, utils
Comment thread
12010486 marked this conversation as resolved.
from model_adapter import HabanaModelAdapter

max_length = None
metadata = None
if args.metadata:
metadata = args.metadata if isinstance(args.metadata, dict) else utils.sample_parse_args_string(args.metadata)
max_length = args.metadata.get("max_length")

if args.fewshot_as_multiturn and args.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
)
if args.samples:
assert args.limit is None, "If --samples is not None, then --limit must be None."
if (samples := Path(args.samples)).is_file():
args.samples = json.loads(samples.read_text())
else:
args.samples = json.loads(args.samples)

with torch.no_grad():
lm = HabanaModelAdapter(tokenizer, model, args, generation_config)
lm = HabanaModelAdapter(tokenizer, model, args, generation_config, max_length=max_length)

from optimum.habana.utils import HabanaGenerationTime, get_hpu_memory_stats

with HabanaGenerationTime() as timer:
with torch.no_grad():
log_samples = args.log_samples
results = evaluator.simple_evaluate(
lm,
tasks=args.tasks,
limit=args.limit_iters,
log_samples=log_samples,
limit=args.limit,
samples=args.samples,
log_samples=args.log_samples,
num_fewshot=args.num_fewshot,
fewshot_as_multiturn=args.fewshot_as_multiturn,
gen_kwargs=args.gen_kwargs,
system_instruction=args.system_instruction,
apply_chat_template=args.apply_chat_template,
metadata=metadata,
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
)
if args.device == "hpu":
import habana_frameworks.torch.hpu as torch_hpu
Expand Down
Loading