Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions tests/example_diff/run_generation.txt
Original file line number Diff line number Diff line change
Expand Up @@ -498,25 +498,28 @@
< help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
---
> help="Whether to enable Habana Flash Attention, provided that the model supports it.",
335,336d233
335,336c234,235
< parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
< args = parser.parse_args()
338,339c235
---
> parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation")
> parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling")
338,339c237
< # Initialize the distributed state.
< distributed_state = PartialState(cpu=args.use_cpu)
---
> args = parser.parse_args()
341c237,238
341c239,240
< logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")
---
> if not args.use_hpu_graphs:
> args.limit_hpu_graphs = False
343,344c240
343,344c242
< if args.seed is not None:
< set_seed(args.seed)
---
> return args
346,373d241
346,373d243
< # Initialize the model and tokenizer
< try:
< args.model_type = args.model_type.lower()
Expand Down Expand Up @@ -545,7 +548,7 @@
< if requires_preprocessing:
< prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
< preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
375,378c243,246
375,378c245,248
< if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
< tokenizer_kwargs = {"add_space_before_punct_symbol": True}
< else:
Expand All @@ -555,7 +558,7 @@
> parser = argparse.ArgumentParser()
> args = setup_parser(parser)
> model, tokenizer, generation_config = initialize_model(args, logger)
380,386c248
380,386c250
< encoded_prompt = tokenizer.encode(
< preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
< )
Expand All @@ -565,7 +568,7 @@
< encoded_prompt = encoded_prompt.to(distributed_state.device)
---
> import habana_frameworks.torch.hpu as torch_hpu
388,389c250,393
388,389c252,395
< if encoded_prompt.size()[-1] == 0:
< input_ids = None
---
Expand Down Expand Up @@ -713,7 +716,7 @@
> print(f"Graph compilation duration = {compilation_duration} seconds")
> print(separator)
> print()
391c395,412
391c397,414
< input_ids = encoded_prompt
---
> # Downloading and loading a dataset from the hub.
Expand All @@ -734,7 +737,7 @@
> .shuffle()
> .select(range(args.dataset_max_samples if args.dataset_max_samples > 0 else (raw_dataset[split]).num_rows))
> )
393,399c414,421
393,399c416,423
< if args.jit:
< jit_input_texts = ["enable jit"]
< jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
Expand All @@ -751,7 +754,7 @@
> logger.info(
> f"No column name was given so automatically choosing '{column_name}' for prompts. If you would like to use another column of the dataset, you can set the argument `--column_name`."
> )
401,439c423,443
401,439c425,445
< sig = inspect.signature(model.__call__)
< jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
< traced_model = torch.jit.trace(model, jit_inputs, strict=False)
Expand Down Expand Up @@ -813,7 +816,7 @@
> preprocess_function,
> batched=True,
> desc="Running tokenizer on dataset",
440a445,522
440a447,524
> # After tokenization, we can remove the column of interest
> raw_dataset = raw_dataset.remove_columns([column_name])
> raw_dataset.set_format(type="torch")
Expand Down Expand Up @@ -892,13 +895,13 @@
> )
> print(separator)
> t_end = time.time()
442,443c524,525
442,443c526,527
< generated_sequences.append(total_sequence)
< print(total_sequence)
---
> throughput = total_new_tokens_generated / duration
> # Print Stats
445c527,539
445c529,541
< return generated_sequences
---
> stats = f"Throughput (including tokenization) = {throughput} tokens/second"
Expand Down