Add clear hpu cache flag for stable perf (#59) #1634

Merged (1 commit) on Jan 31, 2025
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -226,6 +226,11 @@ def setup_parser(parser):
        action="store_true",
        help="Skip HPU Graph usage for first token to save memory",
    )
    parser.add_argument(
        "--clear_hpu_graphs_cache",
        action="store_true",
        help="Clear HPU graphs cache",
    )
    parser.add_argument(
        "--show_graphs_count",
        action="store_true",
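The new switch behaves like the other boolean options in the example script: it defaults to False and becomes True only when passed on the command line. A minimal standalone sketch of that behavior (a throwaway parser mirroring just the flag added above, not the script's actual setup_parser):

import argparse

# Hypothetical parser containing only the new flag, for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--clear_hpu_graphs_cache",
    action="store_true",
    help="Clear HPU graphs cache",
)

print(parser.parse_args([]).clear_hpu_graphs_cache)                            # False
print(parser.parse_args(["--clear_hpu_graphs_cache"]).clear_hpu_graphs_cache)  # True
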
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -649,6 +649,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
    generation_config.trim_logits = args.trim_logits
    generation_config.attn_softmax_bf16 = args.attn_softmax_bf16
    generation_config.limit_hpu_graphs = args.limit_hpu_graphs
    generation_config.clear_hpu_graphs_cache = args.clear_hpu_graphs_cache
    generation_config.reuse_cache = args.reuse_cache
    generation_config.reduce_recompile = args.reduce_recompile
    if generation_config.reduce_recompile:
@@ -21,6 +21,8 @@ class GaudiGenerationConfig(GenerationConfig):
            is also running in lower precision.
        limit_hpu_graphs (`bool`, *optional*):
            Skip HPU Graph usage for first token to save memory
        clear_hpu_graphs_cache (`bool`, *optional*):
            Clear HPU Graph cache
        reuse_cache (`bool`, *optional*):
            Whether to reuse key/value cache for decoding. It should save memory.
        bucket_size (`int`, *optional*):
@@ -46,6 +48,7 @@ def __init__(self, **kwargs):
        self.ignore_eos = kwargs.get("ignore_eos", None)
        self.attn_softmax_bf16 = kwargs.get("attn_softmax_bf16", None)
        self.limit_hpu_graphs = kwargs.get("limit_hpu_graphs", None)
        self.clear_hpu_graphs_cache = kwargs.get("clear_hpu_graphs_cache", None)
        self.reuse_cache = kwargs.get("reuse_cache", None)
        self.bucket_size = kwargs.get("bucket_size", -1)
        self.bucket_internal = kwargs.get("bucket_internal", None)
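Because the new attribute is read from **kwargs with a None default, it can also be set when constructing the config directly. A minimal sketch, assuming GaudiGenerationConfig is importable from optimum.habana.transformers.generation as elsewhere in the repository (import path is an assumption, not part of this PR):

# Assumption: optimum-habana is installed and exposes GaudiGenerationConfig at this path.
from optimum.habana.transformers.generation import GaudiGenerationConfig

config = GaudiGenerationConfig(
    limit_hpu_graphs=True,
    clear_hpu_graphs_cache=True,  # new field introduced by this PR
)
print(config.limit_hpu_graphs, config.clear_hpu_graphs_cache)  # True True
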
11 changes: 10 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -1262,6 +1262,9 @@ def generate(
model_kwargs["use_hpu_graphs"] = hpu_graphs
model_kwargs["limit_hpu_graphs"] = generation_config.limit_hpu_graphs

# determine whether to clear hpu graphs cache
model_kwargs["clear_hpu_graphs_cache"] = generation_config.clear_hpu_graphs_cache

# prepare for allocate kv cache
model_kwargs["reuse_cache"] = generation_config.reuse_cache

@@ -2612,8 +2615,14 @@ def _sample(
            and not model_kwargs.get("reuse_cache", False)
            and bucket_internal
        ):
            # Clear HPU graphs cache
            if model_kwargs.get("clear_hpu_graphs_cache", False):
                self.clear_cache()

            # Clear HPU graphs input tensors of the decode phase after the full generation while loop
            self.clear_inputs()
        else:
            self.clear_inputs()

        # Delete past key value tensors
        self._remove_past_key_values(model_kwargs)

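For context on the plumbing above: generate() copies the config value into model_kwargs, and the cleanup at the end of _sample() reads it back with a falsy default, so the graphs cache is cleared only when the flag is explicitly enabled. A tiny self-contained illustration using plain dicts (not the library API):

# None (unset), False, and True mirror the possible values of
# generation_config.clear_hpu_graphs_cache after argument parsing.
for config_value in (None, False, True):
    model_kwargs = {"clear_hpu_graphs_cache": config_value}
    will_clear = bool(model_kwargs.get("clear_hpu_graphs_cache", False))
    print(config_value, "->", "clear_cache()" if will_clear else "graphs cache kept")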