Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ Here are a few settings you may be interested in:
- `--prompt` to benchmark the model on one or several prompts of your choice
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
- `--fp8` Enable Quantization to fp8

For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:
```bash
Expand Down Expand Up @@ -283,8 +282,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
--use_kv_cache \
--reuse_cache \
--bf16 \
--batch_size 1 \
--fp8
--batch_size 1
```

Alternatively, here is another example to quantize the model based on previous measurements for LLama2-70b:
Expand All @@ -301,8 +299,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
--batch_size 277 \
--max_new_tokens 2048 \
--max_input_tokens 2048 \
--limit_hpu_graphs \
--fp8
--limit_hpu_graphs
```

Here is an example to measure the tensor quantization statistics on Mixtral-8x7B with 1 card:
Expand All @@ -328,8 +325,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati
--bucket_size 128 \
--max_new_tokens 2048 \
--batch_size 16 \
--bf16 \
--fp8
--bf16
```

Here is an example to measure the tensor quantization statistics on Falcon-180B with 8 cards:
Expand Down Expand Up @@ -360,8 +356,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
--batch_size 110 \
--bf16 \
--reuse_cache \
--trim_logits \
--fp8
--trim_logits
```

Here is an example to measure the tensor quantization statistics on phi-2 with 1 card:
Expand Down Expand Up @@ -389,13 +384,9 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_phi.json python run_generation.p
--batch_size 1 \
--bf16 \
--trim_logits \
--reuse_cache \
--fp8
--reuse_cache
```
Comment thread
ssarkar2 marked this conversation as resolved.

`--fp8` is required to enable quantization in fp8.


### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths on prompt stage of inference. Using causal attention mask on prompt stage requires input sequences in batch to be of the same length, but can provide a memory saving, thus enabling higher batch sizes.
Expand Down
36 changes: 26 additions & 10 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
model_on_meta,
write_checkpoints_json,
)
from optimum.habana.utils import check_habana_frameworks_version, check_optimum_habana_min_version, set_seed
from optimum.habana.utils import (
check_habana_frameworks_version,
check_optimum_habana_min_version,
get_habana_frameworks_version,
set_seed,
)


def adjust_batch(batch, size):
Expand Down Expand Up @@ -96,6 +101,22 @@ def setup_distributed(args):
args.global_rank = int(os.getenv("RANK", "0"))


def setup_inference(args, model):
import habana_frameworks.torch.core as htcore

habana_version = get_habana_frameworks_version()

print("Initializing inference mode")
# Keeping the if-else here for back compat. TODO remove later
if habana_version.major >= 1 and habana_version.minor >= 16:
htcore.hpu_initialize(model, mark_only_scales_as_const=True)
else:
const_marking = os.getenv("ENABLE_CONST_MARKING", "True")
if const_marking == "True":
htcore.hpu_initialize(model)
return model


def setup_const_serialization(const_serialization_path):
import uuid

Expand Down Expand Up @@ -132,7 +153,7 @@ def setup_device(args):
if args.device == "hpu":
import habana_frameworks.torch.core as htcore

if args.fp8:
if args.quant_config:
htcore.hpu_set_env()
return torch.device(args.device)

Expand Down Expand Up @@ -405,7 +426,7 @@ def initialize_model(args, logger):
set_seed(args.seed)
get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
use_deepspeed = args.world_size > 0
if use_deepspeed or args.bf16 or args.fp8:
if use_deepspeed or args.bf16:
model_dtype = torch.bfloat16
else:
model_dtype = torch.float
Expand All @@ -429,13 +450,8 @@ def initialize_model(args, logger):

if args.const_serialization_path:
setup_const_serialization(args.const_serialization_path)
if args.fp8:
import habana_frameworks.torch.core as htcore

print("Initializing inference mode")
const_marking = os.getenv("ENABLE_CONST_MARKING", "True")
if const_marking == "True":
htcore.hpu_initialize(model)
if args.quant_config:
model = setup_inference(args, model)
init_end = time.perf_counter()
logger.info(f"Args: {args}")
logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")
Expand Down
1 change: 0 additions & 1 deletion tests/test_text_generation_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def _test_text_generation(
env_variables["QUANT_CONFIG"] = os.path.join(
path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json"
)
command.insert(-2, "--fp8")
command.insert(-2, "--warmup 1")
command.insert(-2, "--n_iterations 2")
if "Llama-2" in model_name or "Mistral" in model_name:
Expand Down