Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2",
"whitelist": {"types": [], "names": []},
"blacklist": {"types": [], "names": ["lm_head"]},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"method": "HOOKS",
"mode": "MEASURE",
"observer": "maxabs",
"whitelist": {"types": [], "names": []},
"blacklist": {"types": [], "names": ["lm_head"]},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
10 changes: 10 additions & 0 deletions examples/text-generation/quantization_config/maxabs_quant.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "maxabs_hw",
"whitelist": {"types": [], "names": []},
"blacklist": {"types": [], "names": ["lm_head"]},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
10 changes: 10 additions & 0 deletions examples/text-generation/quantization_config/unit_scale_quant.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "unit_scale",
"whitelist": {"types": [], "names": []},
"blacklist": {"types": [], "names": ["lm_head"]},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
11 changes: 10 additions & 1 deletion examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import math
import os
import time
from itertools import cycle
from pathlib import Path
Expand Down Expand Up @@ -223,7 +224,7 @@ def setup_parser(parser):
parser.add_argument(
"--kv_cache_fp8",
action="store_true",
help="Store kv-cache in float8 when kv-cache is used",
help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var",
)
parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8")
parser.add_argument(
Expand All @@ -239,6 +240,11 @@ def setup_parser(parser):
if not args.use_hpu_graphs:
args.limit_hpu_graphs = False

args.quant_config = os.getenv("QUANT_CONFIG", "")
if args.quant_config and args.kv_cache_fp8:
# can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization
# with habana quantization toolkit
raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument")
return args


Expand Down Expand Up @@ -539,6 +545,9 @@ def generate_dataset(batch):
if prompt_length > 0:
print(f"Graph compilation duration = {compilation_duration} seconds")
print(separator)
if args.quant_config:
import habana_quantization_toolkit
habana_quantization_toolkit.finish_measurements(model)


if __name__ == "__main__":
Expand Down
13 changes: 12 additions & 1 deletion examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, tokenizer, model, args, options):
if self.model.config.model_type == "llama":
self.model_inputs.update(
{
"reuse_cache" : self.options.reuse_cache,
"attn_softmax_bf16": self.options.attn_softmax_bf16,
}
)
Expand Down Expand Up @@ -125,10 +126,17 @@ def find_bucket(self, length):
return [b for b in self.buckets if b >= length][0]

def _model_call(self, inps):
seq_length = inps.shape[-1]
bs, seq_length = inps.shape
padding_length = 0
if self.options.static_shapes:
bucket_length = self.find_bucket(seq_length)
if self.options.use_cache and self.options.reuse_cache:
self.model.allocate_kv_cache(
bs,
bucket_length + 1,
bucket_length,
False,
)
padding_length = bucket_length - seq_length
inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu()
Expand Down Expand Up @@ -164,6 +172,9 @@ def main():
print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
json.dump(results, open(args.output_file, "w"), indent=2)
print(json.dumps(results, indent=2))
if args.quant_config:
import habana_quantization_toolkit
habana_quantization_toolkit.finish_measurements(model)


if __name__ == "__main__":
Expand Down
19 changes: 14 additions & 5 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ def setup_distributed(args):
args.global_rank = int(os.getenv("RANK", "0"))


def setup_quantization(model):
def setup_quantization(args, model):
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
from habana_frameworks.torch.hpu import hpu

print("Initializing inference with quantization")
_mark_params_as_const(model)
_check_params_as_const(model)

hpu.enable_quantization()
if not args.quant_config:
hpu.enable_quantization()
htcore.hpu_initialize(model)
return model

Expand All @@ -114,6 +114,8 @@ def setup_env(args):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.34.0")
check_optimum_habana_min_version("1.9.0.dev0")
#TODO: SW-167588 - WA for memory issue in hqt prep_model
os.environ.setdefault('EXPERIMENTAL_WEIGHT_SHARING', 'FALSE')

if args.global_rank == 0:
os.environ.setdefault("GRAPH_VISUALIZATION", "true")
Expand Down Expand Up @@ -158,6 +160,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
model = peft_model(args, model_dtype, logger, **model_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
if args.quant_config:
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
model = model.eval().to(args.device)

if args.use_hpu_graphs:
Expand All @@ -178,7 +183,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

logger.info("DeepSpeed is enabled.")
deepspeed.init_distributed(dist_backend="hccl")
config = AutoConfig.from_pretrained(args.model_name_or_path, **model_kwargs)
config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
load_to_meta = model_on_meta(config)

if load_to_meta:
Expand Down Expand Up @@ -227,6 +232,10 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
model = model.module
if model.config.model_type == "llama":
patch_scoped_linear_all_reduce(model)

if args.quant_config:
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model


Expand Down Expand Up @@ -359,7 +368,7 @@ def initialize_model(args, logger):
tokenizer, model = setup_tokenizer(args, model)
generation_config = setup_generation_config(args, model, tokenizer)
if args.fp8:
model = setup_quantization(model)
model = setup_quantization(args, model)
init_end = time.perf_counter()
logger.info(f"Args: {args}")
logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")
Expand Down
89 changes: 53 additions & 36 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@

def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
cur = cur.to(dtype=prev.dtype)
if prev.dtype == torch.float8_e4m3fn:
from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2
cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0]
if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
# Initialize
prev[:, :, :inp_seq_len, :].copy_(cur)
Expand Down Expand Up @@ -92,35 +94,52 @@ def __init__(self):
def forward(self, x, y):
return torch.matmul(x, y)

class KVCache(torch.nn.Module):

def __init__(self):
super(KVCache, self).__init__()
self.cache = None
self.inp_seq_len = -1

def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape):
if self.cache is None or self.cache.shape != shape:
self.inp_seq_len = inp_seq_len
if kv_cache_fp8:
dtype = torch.float8_e4m3fn
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.cache.fill_(0)

def get_shape(self):
if self.cache is None:
return None
return self.cache.shape

def forward(self, cur, dim, idx):
return update(self.cache, cur, dim, idx, self.inp_seq_len)

class GaudiLlamaAttention(LlamaAttention):
def __init__(self, config: LlamaConfig):
super().__init__(config)

self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.past_key = None
self.past_value = None
self.k_cache = KVCache()
self.v_cache = KVCache()
self.inp_seq_len = -1

self.register_buffer("norm_factor", torch.tensor(1.0 / math.sqrt(self.head_dim)), persistent=False)


def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8):
key_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
value_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
if self.past_key is None or self.past_key.shape != key_shape:
self.inp_seq_len = inp_seq_len
device = self.k_proj.weight.device
dtype = self.k_proj.weight.dtype
if kv_cache_fp8:
dtype = torch.float8_e4m3fn
self.past_key = torch.zeros(key_shape, dtype=dtype, device=device)
self.past_value = torch.zeros(value_shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
self.past_key.fill_(0)
self.past_value.fill_(0)
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
device = self.k_proj.weight.device
dtype = self.config.torch_dtype
self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape)
self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape)

def update_sincos_cache(self, seq_len):
# Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings
Expand All @@ -135,14 +154,14 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b):
tensor.copy_(updated)

def reorder_kv_cache(self, beam_idx: torch.LongTensor):
if self.past_key is None:
if self.k_cache.cache is None:
return (None, None)

head_dim = self.past_key.size(-1)
seq_length = self.past_key.size(-2)
self.reorder(self.past_key, beam_idx, seq_length, head_dim)
self.reorder(self.past_value, beam_idx, seq_length, head_dim)
return (self.past_key.shape, self.past_value.shape)
head_dim = self.k_cache.cache.size(-1)
seq_length = self.k_cache.cache.size(-2)
self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)

def pre_attn_forward(
self,
Expand Down Expand Up @@ -212,17 +231,15 @@ def pre_attn_forward(
if past_key_value is not None or reuse_cache:
# reuse k, v, self_attention
if reuse_cache:
past_key = self.past_key
past_value = self.past_value
key_states = self.k_cache(key_states, 2, token_idx)
value_states = self.v_cache(value_states, 2, token_idx)
else:
past_key = past_key_value[0]
past_value = past_key_value[1]
key_states = update(past_key, key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_value, value_states, 2, token_idx, self.inp_seq_len)
key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)

if use_cache:
if reuse_cache:
past_key_value = (self.past_key.shape, self.past_value.shape)
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
else:
past_key_value = (key_states.contiguous(), value_states.contiguous())
else:
Expand Down Expand Up @@ -289,11 +306,11 @@ def pre_attn_forward(
return attn_output, attn_weights, past_key_value

def attention_all_reduce(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
if hasattr(self.o_proj, "all_reduce"):
self.o_proj.all_reduce(attn_output)

def post_attn_forward(self, attn_output):
if self.o_proj.__class__ is ScopedLinearAllReduce:
if hasattr(self.o_proj, "post_all_reduce"):
self.o_proj.post_all_reduce(attn_output)
return attn_output

Expand Down Expand Up @@ -322,13 +339,13 @@ def pre_mlp_forward(self, x):
return output

def mlp_all_reduce(self, x):
if self.down_proj.__class__ is ScopedLinearAllReduce:
if hasattr(self.down_proj, "all_reduce"):
self.down_proj.all_reduce(x)

def post_mlp_forward(self, x):
if self.config.pretraining_tp > 1:
return x
if self.down_proj.__class__ is ScopedLinearAllReduce:
if hasattr(self.down_proj, "post_all_reduce"):
return self.down_proj.post_all_reduce(x)
return x

Expand Down