Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 71 additions & 54 deletions tests/example_diff/run_generation.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
< import inspect
---
> import json
24c24,27
24c24,28
< from typing import Tuple
---
> import math
> import os
> import time
> from itertools import cycle
> from pathlib import Path
27,28c30
27,28c31
< from accelerate import PartialState
< from accelerate.utils import set_seed
---
> from utils import adjust_batch, count_hpu_graphs, initialize_model
30,52c32
30,52c33
< from transformers import (
< AutoTokenizer,
< BloomForCausalLM,
Expand All @@ -47,7 +48,7 @@
< from transformers.modeling_outputs import CausalLMOutputWithPast
---
> from optimum.habana.utils import get_hpu_memory_stats
62,273d41
62,273d42
< MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
<
< MODEL_CLASSES = {
Expand Down Expand Up @@ -260,7 +261,7 @@
< return self._default.prepare_inputs_for_generation(
< input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
< )
275,287c43,45
275,287c44,46
< def _reorder_cache(
< self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
< ) -> Tuple[Tuple[torch.Tensor]]:
Expand All @@ -278,15 +279,15 @@
> def setup_parser(parser):
> # Arguments management
> parser.add_argument("--device", "-d", type=str, choices=["hpu"], help="Device to run", default="hpu")
289c47
289c48
< "--model_type",
---
> "--model_name_or_path",
293c51
293c52
< help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
---
> help="Path to pre-trained model (on the HF Hub or locally).",
296c54,82
296c55,83
< "--model_name_or_path",
---
> "--bf16",
Expand Down Expand Up @@ -318,18 +319,18 @@
> )
> parser.add_argument(
> "--dataset_name",
299,300c85
299,300c86
< required=True,
< help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
---
> help="Optional argument if you want to assess your model on a given dataset of the HF Hub.",
302,306d86
302,306d87
<
< parser.add_argument("--prompt", type=str, default="")
< parser.add_argument("--length", type=int, default=20)
< parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
<
308,311c88,91
308,311c89,92
< "--temperature",
< type=float,
< default=1.0,
Expand All @@ -339,7 +340,7 @@
> default=None,
> type=str,
> help="If `--dataset_name` was given, this will be the name of the column to use as prompts for generation.",
314c94,220
314c95,216
< "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
---
> "--do_sample",
Expand Down Expand Up @@ -452,11 +453,6 @@
> action="store_true",
> help="Whether to reuse key/value cache for decoding. It should save memory.",
> )
> parser.add_argument(
> "--skip_hash_with_views",
> action="store_true",
> help="Whether to skip hash with views for HPU graphs. When skip_hash_with_views is not used, the input to HPU graphs includes both view and base tensors.",
> )
> parser.add_argument("--verbose_workers", action="store_true", help="Enable output from non-master workers")
> parser.add_argument(
> "--simulate_dyn_prompt",
Expand All @@ -469,57 +465,63 @@
> "--reduce_recompile",
> action="store_true",
> help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)",
316,321d221
316,321d217
< parser.add_argument("--k", type=int, default=0)
< parser.add_argument("--p", type=float, default=0.9)
<
< parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
< parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
< parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
323d222
323d218
< parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
325c224
325c220
< "--use_cpu",
---
> "--kv_cache_fp8",
327c226
327c222
< help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available",
---
> help="Store kv-cache in float8 when kv-cache is used",
329c228
> help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var",
329c224
< parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
---
> parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8")
331c230
331c226
< "--fp16",
---
> "--use_flash_attention",
333c232
333c228
< help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
---
> help="Whether to enable Habana Flash Attention, provided that the model supports it.",
335,336c234,235
335,336c230,236
< parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
< args = parser.parse_args()
---
> parser.add_argument(
> "--torch_compile",
> action="store_true",
> help="Whether to use torch compiled model or not.",
> )
> parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation")
> parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling")
338,339c237
338,339c238
< # Initialize the distributed state.
< distributed_state = PartialState(cpu=args.use_cpu)
---
> args = parser.parse_args()
341c239,240
341c240,241
< logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")
---
> if not args.use_hpu_graphs:
> args.limit_hpu_graphs = False
343,344c242
> if args.torch_compile:
> args.use_hpu_graphs = False
343,344c243,244
< if args.seed is not None:
< set_seed(args.seed)
---
> return args
346,373d243
> if not args.use_hpu_graphs:
> args.limit_hpu_graphs = False
346,373c246,251
< # Initialize the model and tokenizer
< try:
< args.model_type = args.model_type.lower()
Expand Down Expand Up @@ -548,17 +550,19 @@
< if requires_preprocessing:
< prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
< preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
375,378c245,248
---
> args.quant_config = os.getenv("QUANT_CONFIG", "")
> if args.quant_config and args.kv_cache_fp8:
> # can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization
> # with habana quantization toolkit
> raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument")
> return args
375,378d252
< if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
< tokenizer_kwargs = {"add_space_before_punct_symbol": True}
< else:
< tokenizer_kwargs = {}
---
> def main():
> parser = argparse.ArgumentParser()
> args = setup_parser(parser)
> model, tokenizer, generation_config = initialize_model(args, logger)
380,386c250
380,386c254,257
< encoded_prompt = tokenizer.encode(
< preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
< )
Expand All @@ -567,11 +571,20 @@
< encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
< encoded_prompt = encoded_prompt.to(distributed_state.device)
---
> import habana_frameworks.torch.hpu as torch_hpu
388,389c252,395
> def main():
> parser = argparse.ArgumentParser()
> args = setup_parser(parser)
> model, tokenizer, generation_config = initialize_model(args, logger)
388,389c259,408
< if encoded_prompt.size()[-1] == 0:
< input_ids = None
---
> use_lazy_mode = True
> if args.torch_compile and model.config.model_type == "llama":
> use_lazy_mode = False
>
> import habana_frameworks.torch.hpu as torch_hpu
>
> if args.dataset_name is None:
> # Benchmark over the prompts below
> if args.prompt:
Expand Down Expand Up @@ -622,7 +635,7 @@
> outputs = model.generate(
> **input_tokens,
> generation_config=generation_config,
> lazy_mode=True,
> lazy_mode=use_lazy_mode,
> hpu_graphs=args.use_hpu_graphs,
> profiling_steps=args.profiling_steps,
> profiling_warmup_steps=args.profiling_warmup_steps,
Expand Down Expand Up @@ -716,7 +729,7 @@
> print(f"Graph compilation duration = {compilation_duration} seconds")
> print(separator)
> print()
391c397,414
391c410,427
< input_ids = encoded_prompt
---
> # Downloading and loading a dataset from the hub.
Expand All @@ -737,7 +750,7 @@
> .shuffle()
> .select(range(args.dataset_max_samples if args.dataset_max_samples > 0 else (raw_dataset[split]).num_rows))
> )
393,399c416,423
393,399c429,436
< if args.jit:
< jit_input_texts = ["enable jit"]
< jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
Expand All @@ -754,7 +767,7 @@
> logger.info(
> f"No column name was given so automatically choosing '{column_name}' for prompts. If you would like to use another column of the dataset, you can set the argument `--column_name`."
> )
401,439c425,445
401,439c438,458
< sig = inspect.signature(model.__call__)
< jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
< traced_model = torch.jit.trace(model, jit_inputs, strict=False)
Expand Down Expand Up @@ -816,7 +829,7 @@
> preprocess_function,
> batched=True,
> desc="Running tokenizer on dataset",
440a447,524
440a460,540
> # After tokenization, we can remove the column of interest
> raw_dataset = raw_dataset.remove_columns([column_name])
> raw_dataset.set_format(type="torch")
Expand Down Expand Up @@ -852,7 +865,7 @@
> outputs = model.generate(
> **batch,
> generation_config=generation_config,
> lazy_mode=True,
> lazy_mode=use_lazy_mode,
> hpu_graphs=args.use_hpu_graphs,
> profiling_steps=args.profiling_steps,
> profiling_warmup_steps=args.profiling_warmup_steps,
Expand Down Expand Up @@ -895,14 +908,12 @@
> )
> print(separator)
> t_end = time.time()
442,443c526,527
< generated_sequences.append(total_sequence)
< print(total_sequence)
---
>
> throughput = total_new_tokens_generated / duration
> # Print Stats
445c529,541
< return generated_sequences
442,443c542,556
< generated_sequences.append(total_sequence)
< print(total_sequence)
---
> stats = f"Throughput (including tokenization) = {throughput} tokens/second"
> separator = "-" * len(stats)
Expand All @@ -917,3 +928,9 @@
> if prompt_length > 0:
> print(f"Graph compilation duration = {compilation_duration} seconds")
> print(separator)
> if args.quant_config:
> import habana_quantization_toolkit
445c558
< return generated_sequences
---
> habana_quantization_toolkit.finish_measurements(model)