Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,19 @@ def generate(
)
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
# Below condition checked explicitly since llama supports bucket_internal even without reuse_cache
if generation_config.bucket_internal:
assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal"
if generation_config.reuse_cache:
assert self.config.model_type in [
"llama",
"mistral",
"falcon",
], "reuse_cache only supported by llama, mistral and falcon at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
), "please set bucket_internal along with reuse_cache and bucket_size"

if generation_config.static_shapes:
# Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs
Expand Down Expand Up @@ -1794,6 +1805,9 @@ def sample(
assert "position_ids" not in model_kwargs, "Untested path"

sample_first = True
model_kwargs["pad_done"] = False
model_kwargs["lazy_mode"] = lazy_mode

# auto-regressive generation
while True:
if lazy_mode:
Expand All @@ -1817,7 +1831,6 @@ def sample(
)

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down Expand Up @@ -1922,6 +1935,21 @@ def sample(
if this_peer_finished and not synced_gpus:
break

if not model_kwargs.get("pad_done", False) and not model_kwargs.get("reuse_cache", False) \
and bucket_internal:
# Pad the returned pask key values tensors from prefill phase forward run to maximum length
# before starting the decode phase.
self._pad_past_key_values(model_kwargs)
model_kwargs["pad_done"] = True

if model_kwargs.get("use_hpu_graphs", False) and model_kwargs.get("limit_hpu_graphs", False) \
and not model_kwargs.get("reuse_cache", False) and bucket_internal:
# Clear HPU graphs input tensors of the decode phase after the full generation while loop
print("CLEAR HPU GRAPH INPUTS OF DECODE PHASE")
self.clear_inputs()
# Delete past key value tensors
self._remove_past_key_values(model_kwargs)

hb_profer.stop()
if streamer is not None:
streamer.end()
Expand Down