Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ python run_generation.py \
The `--bucket_size` option is especially useful when processing an input stream with varying lengths, that is, when you have something like `--dataset_name squad --column_name context --max_input_tokens -1`. `--max_input_tokens -1` specifies that the input prompts in the dataset are not truncated.

Another way to simulate dynamic input is to use `--simulate_dyn_prompt`. For example, `--simulate_dyn_prompt 25 35 45` will extend or crop the default prompt (or the prompt passed in using `--prompt`) to sizes 25, 35, and 45, and throughput will be measured for these 3 lengths. If `--simulate_dyn_prompt` is used, the min and max input lengths from it are computed to perform warmup as well. One final optimization that can be used in case of dynamic inputs is `--reduce_recompile`. Thus the suggested configuration to simulate dynamicity after warmup is to use all three arguments: `--simulate_dyn_prompt 25 35 45 --reduce_recompile --bucket_size 30`

While `--bucket_size` works for any model without model file changes, an even more optimized version of bucketing is supported for certain models like Llama. This can be enabled by setting the `--bucket_internal` flag (along with `--bucket_size` to specify the bucket size). Note that `bucket_internal` also requires `reuse_cache` to be enabled.

### Running with FP8

Llama2-70b and Llama2-7b in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
Expand Down
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ def setup_parser(parser):
then we use `shape = prompt_length + max_new_tokens`. If a positive number is passed \
we increase the bucket in steps of `bucket_size` instead of allocating to max (`prompt_length + max_new_tokens`).",
)
parser.add_argument(
"--bucket_internal",
action="store_true",
help="Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large.",
)
parser.add_argument(
"--dataset_max_samples",
default=-1,
Expand Down
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def setup_generation_config(args, model, tokenizer):
generation_config.use_cache = args.use_kv_cache
generation_config.static_shapes = is_optimized
generation_config.bucket_size = args.bucket_size if is_optimized else -1
generation_config.bucket_internal = args.bucket_internal
generation_config.do_sample = args.do_sample
generation_config.num_beams = args.num_beams
generation_config.bad_words_ids = bad_words_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, **kwargs):
self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.bucket_internal = kwargs.get("bucket_internal", None)
self.reduce_recompile = kwargs.get("reduce_recompile", None)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
Expand Down
50 changes: 37 additions & 13 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,17 +589,25 @@ def generate(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

is_greedy_or_beam_and_bucket = generation_config.bucket_size > 0 and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
is_greedy_or_beam_and_bucket = (
not generation_config.bucket_internal
and generation_config.bucket_size > 0
and (
self._get_generation_mode(generation_config, assistant_model) == GenerationMode.GREEDY_SEARCH
or self._get_generation_mode(generation_config, assistant_model) == GenerationMode.BEAM_SEARCH
)
)
model_kwargs["bucket_size"] = generation_config.bucket_size if generation_config.static_shapes else -1
model_kwargs["bucket_internal"] = generation_config.bucket_internal
model_kwargs["reduce_recompile"] = (
generation_config.reduce_recompile if generation_config.reduce_recompile is not None else False
)
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
if generation_config.reuse_cache:
if generation_config.bucket_internal:
assert generation_config.bucket_size >= 0, "bucket_internal and bucket_size flags set together"
assert generation_config.reuse_cache, "please set reuse_cache to use bucket_internal"
if generation_config.reuse_cache and not generation_config.bucket_internal:
assert generation_config.bucket_size <= 0, "reuse_cache and bucketing flags set together"

if generation_config.static_shapes:
Expand Down Expand Up @@ -714,6 +722,8 @@ def generate(
token_idx,
generation_config.kv_cache_fp8,
)
model_kwargs["kv_cache_len"] = calculated_max_length

if self.config.model_type in ["llama"]:
if self.config.max_position_embeddings < calculated_max_length:
unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)
Expand Down Expand Up @@ -1370,13 +1380,16 @@ def greedy_search(
hb_profer.start()
this_peer_finished = False # used by synced_gpus only
bucket_size = model_kwargs["bucket_size"]
prev_idx = None # avoiding calculate cache_idx when its value is not changing
bucket_internal = model_kwargs["bucket_internal"]
reduce_recompile = model_kwargs["reduce_recompile"]

prompt_len = input_ids.shape[-1]
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"
if not bucket_internal:
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"

while True:
if lazy_mode:
Expand All @@ -1393,11 +1406,22 @@ def greedy_search(
break

if bucket_size > 0:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
if not bucket_internal:
# it will not have been padded if bucket_size > 0
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)
else:
# Calculate slice idx for kv cache. Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
if idx != prev_idx:
cache_idx = (idx.item() + 1) * bucket_size
model_kwargs["cache_idx"] = cache_idx
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

Copy link
Copy Markdown

@puneeshkhanna puneeshkhanna Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xt574chen - this logic will work only when your total generated length is a multiple of the bucket size. For example, consider a total length of 2060: for tokens generated between 2048 and 2060, the KV cache will be sliced only up to sequence length 2048, so the KV values between 2048 and 2060 won't be considered.

Please find updated logic below (spent a lot of time reviewing all the changes today):

if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
     idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
     cache_idx = (idx.item() + 1) * bucket_size
     model_kwargs["cache_idx"] = cache_idx
else:
     model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

We can also enhance this a bit further by avoiding the `.item()` call when the idx tensor is not changing. But let's skip that minor enhancement for now; we can push a separate PR later.

More importantly, the above logic needs to go in first so that the model logic works correctly.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xt574chen Further enhanced code can be as below. I will let you decide the best course of action.

#Declare prev_idx = None outside while loop.

if model_kwargs.get("token_idx") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
     idx = torch.div(model_kwargs.get("token_idx") - 1, bucket_size, rounding_mode="floor")
     if idx != prev_idx:
         cache_idx = (idx.item() + 1) * bucket_size
         model_kwargs["cache_idx"] = cache_idx
         prev_idx = idx
else:
     model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested the recommended logic (without the enhancement) quite a bit and it seems to be working fine. @xt574chen - please test from your side and feel free to update anything. But the original code has the issue that I highlighted in my earlier comments.

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
Expand Down
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Expand Down Expand Up @@ -263,6 +264,12 @@ def pre_attn_forward(
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if cache_idx is not None and q_len == 1:
key_states = key_states[:, :, :cache_idx, :]
value_states = value_states[:, :, :cache_idx, :]
attention_mask = attention_mask[:, :, :, :cache_idx]
kv_seq_len = key_states.shape[-2]

if use_cache:
if reuse_cache:
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
Expand Down Expand Up @@ -417,6 +424,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Expand All @@ -441,6 +449,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
self.self_attn.attention_all_reduce(output_pre_attn)
Expand Down Expand Up @@ -470,6 +479,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
Expand All @@ -485,6 +495,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
return output_attn, attn_weights, present_key_value
Expand Down Expand Up @@ -534,6 +545,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Expand Down Expand Up @@ -646,6 +658,7 @@ def custom_forward(*inputs):
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)

Expand Down Expand Up @@ -688,6 +701,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8)
self.kv_cache_len = max_seq_len

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
return self.model.reorder_kv_cache(beam_idx)
Expand All @@ -713,6 +727,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand All @@ -736,6 +751,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
hidden_states = outputs[0]
Expand Down Expand Up @@ -822,6 +838,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"cache_idx": kwargs.get("cache_idx"),
}
)
return model_inputs
Expand Down