Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,19 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
"--torch_compile",
action="store_true",
help="Whether to use torch compiled model or not.",
)
parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation")
parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling")

args = parser.parse_args()

if args.torch_compile:
args.use_hpu_graphs = False

if not args.use_hpu_graphs:
args.limit_hpu_graphs = False

Expand All @@ -247,6 +255,10 @@ def main():
args = setup_parser(parser)
model, tokenizer, generation_config = initialize_model(args, logger)

use_lazy_mode = True
if args.torch_compile and model.config.model_type == "llama":
use_lazy_mode = False

import habana_frameworks.torch.hpu as torch_hpu

if args.dataset_name is None:
Expand Down Expand Up @@ -299,7 +311,7 @@ def generate(size=None, reduce_recompile=False):
outputs = model.generate(
**input_tokens,
generation_config=generation_config,
lazy_mode=True,
lazy_mode=use_lazy_mode,
hpu_graphs=args.use_hpu_graphs,
profiling_steps=args.profiling_steps,
profiling_warmup_steps=args.profiling_warmup_steps,
Expand Down Expand Up @@ -479,7 +491,7 @@ def generate_dataset(batch):
outputs = model.generate(
**batch,
generation_config=generation_config,
lazy_mode=True,
lazy_mode=use_lazy_mode,
hpu_graphs=args.use_hpu_graphs,
profiling_steps=args.profiling_steps,
profiling_warmup_steps=args.profiling_warmup_steps,
Expand Down
11 changes: 10 additions & 1 deletion examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def setup_env(args):
check_min_version("4.34.0")
check_optimum_habana_min_version("1.9.0.dev0")

if args.global_rank == 0:
if args.global_rank == 0 and not args.torch_compile:
os.environ.setdefault("GRAPH_VISUALIZATION", "true")
shutil.rmtree(".graph_dumps", ignore_errors=True)

Expand Down Expand Up @@ -151,6 +151,11 @@ def patch_scoped_linear_all_reduce(model):
patch_scoped_linear_all_reduce(module)


def get_torch_compiled_model(model):
    """Swap the inner ``model.model`` submodule for its ``torch.compile``d version.

    The submodule is compiled with the ``aot_hpu_inference_backend`` backend and
    assigned back onto ``model``; the same ``model`` object is returned.
    """
    compiled_inner = torch.compile(model.model, backend="aot_hpu_inference_backend")
    model.model = compiled_inner
    return model


def setup_model(args, model_dtype, model_kwargs, logger):
logger.info("Single-device run.")

Expand All @@ -170,6 +175,10 @@ def setup_model(args, model_dtype, model_kwargs, logger):
model = wrap_in_hpu_graph(model, hash_with_views=not args.skip_hash_with_views)
else:
model = wrap_in_hpu_graph(model)

if args.torch_compile and model.config.model_type == "llama":
model = get_torch_compiled_model(model)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kausikmaiti Can we add a model-specific check, since generation using torch.compile isn't verified on all models?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a model-specific check in a separate commit. Please review.


return model


Expand Down
35 changes: 33 additions & 2 deletions tests/test_text_generation_example.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

os.environ["WORLD_SIZE"] = "0" ? WORLD_SIZE should be set to 1 for 1x runs

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you mentioned offline, the WORLD_SIZE setting does not matter here, since I'm not using DeepSpeed or the gaudi_spawn.py script.
Also, from what I observed, if I don't set WORLD_SIZE=0, logic such as "use_deepspeed = args.world_size > 0" causes setup_distributed_model() to be called, and the test fails at a very early stage while importing deepspeed. That is not the expected behavior.

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
("meta-llama/Llama-2-70b-hf", 58.2750262232098),
("facebook/opt-66b", 28.16154122335556),
],
"torch_compile": [
("meta-llama/Llama-2-7b-hf", 8.95169640119334),
],
}
else:
# Gaudi1 CI baselines
Expand All @@ -50,13 +53,22 @@
"deepspeed": [
("bigscience/bloomz-7b1", 27.34439410425298),
],
"torch_compile": [],
}


def _test_text_generation(model_name: str, baseline: float, token: str, deepspeed: bool = False, world_size: int = 8):
def _test_text_generation(
model_name: str,
baseline: float,
token: str,
deepspeed: bool = False,
world_size: int = 8,
torch_compile: bool = False,
):
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"

deepspeed = deepspeed and not torch_compile
if deepspeed:
command += [
f"{path_to_example_dir / 'gaudi_spawn.py'}",
Expand All @@ -68,11 +80,22 @@ def _test_text_generation(model_name: str, baseline: float, token: str, deepspee
f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}",
f"--model_name_or_path {model_name}",
"--batch_size 1",
"--use_hpu_graphs",
"--use_kv_cache",
"--max_new_tokens 100",
]

if torch_compile:
command += [
"--attn_softmax_bf16",
"--reuse_cache",
"--trim_logits",
"--torch_compile",
]
else:
command += [
"--use_hpu_graphs",
]

if not deepspeed:
command.append("--bf16")

Expand Down Expand Up @@ -115,3 +138,11 @@ def test_text_generation_bf16(model_name: str, baseline: float, token: str):
def test_text_generation_deepspeed(model_name: str, baseline: float, token: str):
world_size = 2 if "opt-66b" in model_name else 8
_test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size)


@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile"])
def test_text_generation_torch_compile(model_name: str, baseline: float, token: str):
    """Run the text-generation example through ``_test_text_generation`` with ``torch_compile=True``.

    Three environment variables are set for the duration of the run and then
    restored to their previous state, so they do not leak into tests that run
    later in the same process.

    Args:
        model_name: Model identifier taken from ``MODELS_TO_TEST["torch_compile"]``.
        baseline: Baseline value paired with ``model_name`` in ``MODELS_TO_TEST``.
        token: Forwarded unchanged to ``_test_text_generation``.
    """
    # NOTE(review): WORLD_SIZE=0 is reported to keep the example off the
    # DeepSpeed setup path -- confirm against the example's utils.py.
    env_overrides = {
        "PT_ENABLE_INT64_SUPPORT": "1",
        "PT_HPU_LAZY_MODE": "0",
        "WORLD_SIZE": "0",
    }
    # Snapshot the prior values (None means "was not set") before overriding.
    previous_values = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        _test_text_generation(model_name, baseline, token, torch_compile=True)
    finally:
        # Restore the environment even when the test fails.
        for name, previous in previous_values.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous