diff --git a/examples/common_parser.py b/examples/common_parser.py new file mode 100644 index 0000000000..2f0d329da3 --- /dev/null +++ b/examples/common_parser.py @@ -0,0 +1,21 @@ +from argparse import ArgumentParser + + +def add_profiling_args(parser: ArgumentParser) -> None: + parser.add_argument( + "--profiling_warmup_steps", + default=0, + type=int, + help="Number of steps to ignore for profiling.", + ) + parser.add_argument( + "--profiling_steps", + default=0, + type=int, + help="Number of steps to capture for profiling.", + ) + parser.add_argument( + "--profiling_record_shapes", + action="store_true", + help="Record shapes when enabling profiling.", + ) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ceda678419..e00c4768db 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -23,6 +23,7 @@ import logging import math import os +import sys import struct from itertools import cycle from pathlib import Path @@ -40,6 +41,12 @@ ) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +from examples.common_parser import add_profiling_args # noqa: E402 + + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -140,23 +147,7 @@ def setup_parser(parser): type=int, help="Seed to use for random generation. Useful to reproduce your runs with `--do_sample`.", ) - parser.add_argument( - "--profiling_warmup_steps", - default=0, - type=int, - help="Number of steps to ignore for profiling.", - ) - parser.add_argument( - "--profiling_steps", - default=0, - type=int, - help="Number of steps to capture for profiling.", - ) - parser.add_argument( - "--profiling_record_shapes", - action="store_true", - help="Record shapes when enabling profiling.", - ) + add_profiling_args(parser) parser.add_argument( "--profile_whole_sequences", action="store_true", diff --git a/examples/video-comprehension/run_example.py b/examples/video-comprehension/run_example.py index b53679fb0b..258cf444c8 100644 --- a/examples/video-comprehension/run_example.py +++ b/examples/video-comprehension/run_example.py @@ -17,6 +17,7 @@ import json import logging import os +import sys import time from pathlib import Path @@ -32,6 +33,12 @@ ) +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +from examples.common_parser import add_profiling_args # noqa: E402 + + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", @@ -118,6 +125,7 @@ def main(): help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", ) + add_profiling_args(parser) args = parser.parse_args() os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE") @@ -186,6 +194,9 @@ def main(): ) torch.hpu.synchronize() + from optimum.habana.utils import HabanaProfile + + HabanaProfile.enable() start = time.perf_counter() for i in range(args.n_iterations): generate_ids = model.generate( @@ -196,12 +207,16 @@ def main(): ignore_eos=args.ignore_eos, use_flash_attention=args.use_flash_attention, flash_attention_recompute=args.flash_attention_recompute, + profiling_steps=args.profiling_steps, + profiling_warmup_steps=args.profiling_warmup_steps, + profiling_record_shapes=args.profiling_record_shapes, ) generate_texts = processor.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) end = time.perf_counter() duration = end - start + HabanaProfile.disable() # Let's calculate the number of generated tokens n_input_tokens = inputs["input_ids"].shape[1]