Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions examples/common_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from argparse import ArgumentParser


def add_profiling_args(parser: ArgumentParser) -> None:
parser.add_argument(
"--profiling_warmup_steps",
default=0,
type=int,
help="Number of steps to ignore for profiling.",
)
parser.add_argument(
"--profiling_steps",
default=0,
type=int,
help="Number of steps to capture for profiling.",
)
parser.add_argument(
"--profiling_record_shapes",
action="store_true",
help="Record shapes when enabling profiling.",
)
25 changes: 8 additions & 17 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import math
import os
import sys
import struct
from itertools import cycle
from pathlib import Path
Expand All @@ -40,6 +41,12 @@
)


project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if project_root not in sys.path:
sys.path.insert(0, project_root)
Comment on lines +44 to +46
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these lines?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to import this:
from examples.common_parser import add_profiling_args
I do believe it's good approach to use the examples folder for shared code among examples
alternatively:

  • I could use a relative import, but these scripts are not run as a module.
  • I could modify PYTHONPATH, but this would cause changes to all tests and instructions for running these scripts.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@regisss can you check my comment please

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexey-belyakov, I would suggest to wait for #1931 to be merged as there is some rebase of profiling code.

Besides that, probably another solution in case you do not want to rewrite code multiple times is to create a file like examples_utils.py, or the like, under optimum/habana which is already in the path. To avoid the solution you provided which is not very clear.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly I'm also fine in not creating a new file for the three profiling args that are repeating : ) But I can understand the pain

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimum/habana does not contain code related to examples. examples folder is the most obvious folder for files related to examples.
It is always possible to rename the file after adding more logic to it.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that solution is okay. Can you still check if it still holds after #1931 as suggested by @12010486 please?

from examples.common_parser import add_profiling_args # noqa: E402


logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
Expand Down Expand Up @@ -140,23 +147,7 @@ def setup_parser(parser):
type=int,
help="Seed to use for random generation. Useful to reproduce your runs with `--do_sample`.",
)
parser.add_argument(
"--profiling_warmup_steps",
default=0,
type=int,
help="Number of steps to ignore for profiling.",
)
parser.add_argument(
"--profiling_steps",
default=0,
type=int,
help="Number of steps to capture for profiling.",
)
parser.add_argument(
"--profiling_record_shapes",
action="store_true",
help="Record shapes when enabling profiling.",
)
add_profiling_args(parser)
parser.add_argument(
"--profile_whole_sequences",
action="store_true",
Expand Down
15 changes: 15 additions & 0 deletions examples/video-comprehension/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import logging
import os
import sys
import time
from pathlib import Path

Expand All @@ -32,6 +33,12 @@
)


project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if project_root not in sys.path:
sys.path.insert(0, project_root)
Comment on lines +36 to +38
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

from examples.common_parser import add_profiling_args # noqa: E402


logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
Expand Down Expand Up @@ -118,6 +125,7 @@ def main():
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)

add_profiling_args(parser)
args = parser.parse_args()

os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
Expand Down Expand Up @@ -186,6 +194,9 @@ def main():
)
torch.hpu.synchronize()

from optimum.habana.utils import HabanaProfile

HabanaProfile.enable()
start = time.perf_counter()
for i in range(args.n_iterations):
generate_ids = model.generate(
Expand All @@ -196,12 +207,16 @@ def main():
ignore_eos=args.ignore_eos,
use_flash_attention=args.use_flash_attention,
flash_attention_recompute=args.flash_attention_recompute,
profiling_steps=args.profiling_steps,
profiling_warmup_steps=args.profiling_warmup_steps,
profiling_record_shapes=args.profiling_record_shapes,
)
generate_texts = processor.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
end = time.perf_counter()
duration = end - start
HabanaProfile.disable()

# Let's calculate the number of generated tokens
n_input_tokens = inputs["input_ids"].shape[1]
Expand Down
Loading