Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ python run_generation.py \
The `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is, when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies that the input prompts in the dataset are not truncated.

Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example, `--simulate_dyn_prompt 25 35 45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. Thus, the suggested configuration to simulate dynamic inputs after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30`.

While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size).

### Running with FP8

Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
Expand Down
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def setup_parser(parser):
then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \
we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).",
)
parser.add_argument(
"--bucket_internal",
action="store_true",
help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.",
)
parser.add_argument(
"--dataset_max_samples",
default=-1,
Expand Down
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def setup_generation_config(args, model, tokenizer):
generation_config.use_cache = args.use_kv_cache
generation_config.static_shapes = is_optimized
generation_config.bucket_size = args.bucket_size if is_optimized else -1
generation_config.bucket_internal = args.bucket_internal
generation_config.do_sample = args.do_sample
generation_config.num_beams = args.num_beams
generation_config.bad_words_ids = bad_words_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class GaudiGenerationConfig(GenerationConfig):
If negative (default=-1) pad to max if `static_shapes` is set. Else start with
`shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow space by `bucket_size` when needed.
Only active if `static_shapes` is used. Can't be used with `reuse_cache`.
bucket_internal (`bool`, *optional*):
Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.
kv_cache_fp8 (`bool`, *optional*):
Store kv-cache in float8 when kv-cache is used
use_flash_attention (`bool`, *optional*):
Expand All @@ -44,6 +46,7 @@ def __init__(self, **kwargs):
self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.bucket_internal = kwargs.get("bucket_internal", None)
Comment thread
regisss marked this conversation as resolved.
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
Expand Down
52 changes: 39 additions & 13 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,19 +550,29 @@ def generate(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
is_greedy_or_beam_and_bucket = (
not generation_config.bucket_internal
and generation_config.bucket_size > 0
and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
)
)
model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1
model_kwargs["bucket_internal"] = generation_config.bucket_internal
model_kwargs["reduce_recompile"] = (
generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False
)
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
if generation_config.reuse_cache:
assert self.config.model_type in ["llama"], "reuse_cache only supported by llama at the moment"
assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
), "please set bucket_internal along with reuse_cache and bucket_size"
else:
assert generation_config.bucket_size >= 0, "please set valid bucket_size to use bucket_internal"

if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
Expand Down Expand Up @@ -690,6 +700,8 @@ def generate(
token_idx,
generation_config.kv_cache_fp8,
)
model_kwargs["kv_cache_len"] = calculated_max_length

if self.config.model_type in ["llama"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)
Expand Down Expand Up @@ -1329,12 +1341,15 @@ def greedy_search(
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs.get("bucket_size", -1)
reduce_recompile = model_kwargs.get("reduce_recompile", False)
prev_idx = None # avoiding calculate cache_idx when its value is not changing
bucket_internal = model_kwargs.get("bucket_internal", None)

prompt_len = input_ids.shape[-1]
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"
if not bucket_internal:
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"
cur_len = prompt_len
token_idx = model_kwargs.get("token_idx", None)
if token_idx is not None:
Expand All @@ -1355,11 +1370,22 @@ def greedy_search(
break

if bucket_size > 0:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
if not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
else:
# Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
if idx != prev_idx:
cache_idx = (idx.item() + 1) * bucket_size
model_kwargs["cache_idx"] = cache_idx
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
Expand Down
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Expand Down Expand Up @@ -281,6 +282,12 @@ def pre_attn_forward(
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
if reuse_cache:
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
Expand Down Expand Up @@ -445,6 +452,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand Down Expand Up @@ -474,6 +482,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
**kwargs,
)
self.self_attn.attention_all_reduce(output_pre_attn)
Expand Down Expand Up @@ -503,6 +512,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
Expand All @@ -517,6 +527,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
cache_idx=cache_idx,
)
return output_attn, attn_weights, present_key_value

Expand Down Expand Up @@ -565,6 +576,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -681,6 +693,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -728,6 +741,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
self.kv_cache_len = max_seq_len

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
Expand All @@ -753,6 +767,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -775,6 +790,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -886,6 +902,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"cache_idx": kwargs.get("cache_idx"),
}
)
return model_inputs
Expand Down