Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def setup_env(args):
os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")

if args.use_hpu_graphs and args.limit_hpu_graphs and not args.reuse_cache and args.bucket_internal:
# Based upon above conditions and below env variable,
# we can call HPU graphs clear_inputs().
os.environ.setdefault("PT_HPUGRAPH_DISABLE_TENSOR_CACHE", "1")

# Tweak generation so that it runs faster on Gaudi
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

Expand Down
87 changes: 80 additions & 7 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,29 @@ def _expand_dict_for_generation(dict_to_expand):

return input_ids, model_kwargs

def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
if model_kwargs["past_key_values"]:
for i in range(len(model_kwargs["past_key_values"])):
for j in range(len(model_kwargs["past_key_values"][i])):
if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
)
if model_kwargs.get("lazy_mode", False):
self.htcore_generation.mark_step()

def _remove_past_key_values(self, model_kwargs):
if model_kwargs["past_key_values"]:
for i in range(len(model_kwargs["past_key_values"])):
for j in range(len(model_kwargs["past_key_values"][i])):
if torch.is_tensor(model_kwargs["past_key_values"][i][j]):
t = model_kwargs["past_key_values"][i][j]
del t
model_kwargs["past_key_values"][i][j] = None
del model_kwargs["past_key_values"]
model_kwargs["past_key_values"] = None

def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
Expand All @@ -278,10 +301,11 @@ def _update_model_kwargs_for_generation(
"""
# mark to identify starting from second token
model_kwargs["first_token"] = False
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if not model_kwargs.get("pad_done", False):
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state

Expand Down Expand Up @@ -745,6 +769,9 @@ def generate(
)
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
# Below condition checked explicitly since llama supports bucket_internal even without reuse_cache
if generation_config.bucket_internal:
assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal"
if generation_config.reuse_cache:
assert self.config.model_type in [
"llama",
Expand Down Expand Up @@ -878,6 +905,7 @@ def generate(
model_kwargs["attn_softmax_bf16"] = generation_config.attn_softmax_bf16

# determine whether limit_hpu_graphs needs to be used
model_kwargs["use_hpu_graphs"] = hpu_graphs
model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs

# prepare for allocate kv cache
Expand All @@ -899,7 +927,9 @@ def generate(
unwrap_deepspeed_model(self).allocate_kv_cache(
bs * generation_config.num_beams, calculated_max_length, token_idx + num_virtual_tokens
)
model_kwargs["kv_cache_len"] = calculated_max_length
if generation_config.use_cache:
model_kwargs["kv_cache_len"] = calculated_max_length
model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens

if self.config.model_type in ["llama", "falcon", "mistral"]:
if self.config.max_position_embeddings < calculated_max_length:
Expand Down Expand Up @@ -1618,6 +1648,8 @@ def _greedy_search(
cur_len = token_idx.item()

time_to_first_token_done = False
model_kwargs["pad_done"] = False
model_kwargs["lazy_mode"] = lazy_mode
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
if lazy_mode:
self.htcore_generation.mark_step()
Expand All @@ -1630,7 +1662,6 @@ def _greedy_search(
)

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down Expand Up @@ -1737,6 +1768,15 @@ def _greedy_search(
)
this_peer_finished = unfinished_sequences.max() == 0

if (
not model_kwargs.get("pad_done", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Pad the returned pask key values tensors from prefill phase forward run to maximum length
# before starting the decode phase.
self._pad_past_key_values(model_kwargs)
model_kwargs["pad_done"] = True
hb_profer.step()
if hb_gen_time is not None:
if not time_to_first_token_done:
Expand All @@ -1746,6 +1786,17 @@ def _greedy_search(
torch_hpu.synchronize()
hb_gen_time.step()

if (
model_kwargs.get("use_hpu_graphs", False)
and model_kwargs.get("limit_hpu_graphs", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Clear HPU graphs input tensors of the decode phase after the full generation while loop
self.clear_inputs()
# Delete past key value tensors
self._remove_past_key_values(model_kwargs)

hb_profer.stop()
if streamer is not None:
streamer.end()
Expand Down Expand Up @@ -2016,6 +2067,8 @@ def _sample(

# auto-regressive generation
time_to_first_token_done = False
model_kwargs["pad_done"] = False
model_kwargs["lazy_mode"] = lazy_mode
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
if lazy_mode:
self.htcore_generation.mark_step()
Expand All @@ -2028,7 +2081,6 @@ def _sample(
)

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down Expand Up @@ -2130,6 +2182,16 @@ def _sample(
)
this_peer_finished = unfinished_sequences.max() == 0

if (
not model_kwargs.get("pad_done", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Pad the returned pask key values tensors from prefill phase forward run to maximum length
# before starting the decode phase.
self._pad_past_key_values(model_kwargs)
model_kwargs["pad_done"] = True

hb_profer.step()
if hb_gen_time is not None:
if not time_to_first_token_done:
Expand All @@ -2139,6 +2201,17 @@ def _sample(
torch_hpu.synchronize()
hb_gen_time.step()

if (
model_kwargs.get("use_hpu_graphs", False)
and model_kwargs.get("limit_hpu_graphs", False)
and not model_kwargs.get("reuse_cache", False)
and bucket_internal
):
# Clear HPU graphs input tensors of the decode phase after the full generation while loop
self.clear_inputs()
# Delete past key value tensors
self._remove_past_key_values(model_kwargs)

hb_profer.stop()
if streamer is not None:
streamer.end()
Expand Down
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def adapt_transformers_to_gaudi():
GaudiGenerationMixin.update_model_kwargs_for_bucketing
)
transformers.generation.GenerationMixin._get_hpu_graphs_kwargs = GaudiGenerationMixin._get_hpu_graphs_kwargs
transformers.generation.GenerationMixin._pad_past_key_values = GaudiGenerationMixin._pad_past_key_values
transformers.generation.GenerationMixin._remove_past_key_values = GaudiGenerationMixin._remove_past_key_values
transformers.generation.GenerationMixin._expand_inputs_for_generation = staticmethod(
GaudiGenerationMixin._expand_inputs_for_generation
)
Expand Down
15 changes: 12 additions & 3 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def pre_attn_forward(
past_value = torch.zeros(
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
)
past_key_value = (past_key, past_value)
# Return list instead of tuple
past_key_value = [past_key, past_value]
if (
token_idx is not None
and num_virtual_tokens is not None
Expand Down Expand Up @@ -479,6 +480,11 @@ def pre_attn_forward(
if not output_attentions:
attn_weights = None

if not reuse_cache and token_idx is not None and cache_idx is not None and q_len == 1:
# Return only past key value shapes and not the tensors during decode phase (q len is 1)
# to avoid making past key values as persistent output tensors of HPU graphs.
past_key_value = (past_key_value[0].shape, past_key_value[1].shape)

return attn_output, attn_weights, past_key_value

def attention_all_reduce(self, attn_output):
Expand Down Expand Up @@ -825,6 +831,7 @@ def forward(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
None,
)
else:
layer_outputs = decoder_layer(
Expand Down Expand Up @@ -1003,6 +1010,7 @@ def prepare_inputs_for_generation(
past_length = 0

reuse_cache = kwargs.get("reuse_cache")
bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
Expand Down Expand Up @@ -1041,8 +1049,9 @@ def prepare_inputs_for_generation(
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
elif reuse_cache and token_idx is not None:
# With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
elif (reuse_cache or bucket_internal) and token_idx is not None:
# KV cache is pre allocated with reuse cache or will be padded with bucket internal
# hence for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]

Expand Down