Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,9 @@ def generate(
# if generation_config.bucket_size <= 0, padding is handled by the generating fn (like greedy_search)
if generation_config.static_shapes and generation_config.bucket_size > 0:
assert (
generation_mode == GenerationMode.GREEDY_SEARCH or generation_mode == GenerationMode.BEAM_SEARCH
generation_mode == GenerationMode.GREEDY_SEARCH
or generation_mode == GenerationMode.SAMPLE
or generation_mode == GenerationMode.BEAM_SEARCH
), "generation_config.bucket_size > 0 supported only for greedy mode"

if streamer is not None and (generation_config.num_beams > 1):
Expand Down Expand Up @@ -1763,6 +1765,18 @@ def sample(
# Update cur_len in case of static shapes
cur_len = token_idx.item()

bucket_size = model_kwargs.get("bucket_size", -1)
prev_idx = -1 # avoid recalculating cache_idx when its value has not changed
bucket_internal = model_kwargs.get("bucket_internal", None)
reduce_recompile = model_kwargs.get("reduce_recompile", False)

prompt_len = input_ids.shape[-1]
if not bucket_internal:
if bucket_size >= 0:
inc = iter(incrementor(bucket_size, prompt_len))
if bucket_size > 0:
assert "position_ids" not in model_kwargs, "Untested path"

# auto-regressive generation
while True:
if lazy_mode:
Expand All @@ -1778,6 +1792,13 @@ def sample(
if this_peer_finished_flag.item() == 0.0:
break

if bucket_size > 0 and not bucket_internal:
# when bucket_size > 0, input_ids has not been padded beforehand, so the bucketing update below handles it
params = next(inc)
input_ids, model_kwargs = self.update_model_kwargs_for_bucketing(
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)

# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

Expand Down Expand Up @@ -1849,6 +1870,17 @@ def sample(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
cur_len = cur_len + 1
if bucket_size > 0 and bucket_internal:
# Calculate slice idx for kv cache during the decode phase.
# Breaking down the kv cache in the attention block helps to reduce computation time.
if model_kwargs.get("token_idx_cpu") <= (model_kwargs["kv_cache_len"] // bucket_size) * bucket_size:
idx = (model_kwargs.get("token_idx_cpu") - 1) // bucket_size
if prev_idx != idx:
model_kwargs["cache_idx"] = (idx + 1) * bucket_size
prev_idx = idx
else:
model_kwargs["cache_idx"] = model_kwargs["kv_cache_len"]

# if eos_token was found in one sentence, set sentence to finished
if not ignore_eos and eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
Expand Down