diff --git a/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json new file mode 100644 index 0000000000..3b69844e44 --- /dev/null +++ b/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_config/maxabs_measure.json b/examples/text-generation/quantization_config/maxabs_measure.json new file mode 100644 index 0000000000..8db84a3b21 --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_measure.json @@ -0,0 +1,9 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_quant.json b/examples/text-generation/quantization_config/maxabs_quant.json new file mode 100644 index 0000000000..b83a880d65 --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} \ No newline at end of file diff --git 
a/examples/text-generation/quantization_config/unit_scale_quant.json b/examples/text-generation/quantization_config/unit_scale_quant.json new file mode 100644 index 0000000000..0517bd784d --- /dev/null +++ b/examples/text-generation/quantization_config/unit_scale_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "unit_scale", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./hqt_output/measure", + "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx" +} diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ab308e7023..ecfbe15dca 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -22,6 +22,7 @@ import json import logging import math +import os import time from itertools import cycle from pathlib import Path @@ -223,7 +224,7 @@ def setup_parser(parser): parser.add_argument( "--kv_cache_fp8", action="store_true", - help="Store kv-cache in float8 when kv-cache is used", + help="Store kv-cache in float8 when kv-cache is used. 
Can't use this argument together with QUANT_CONFIG env var", ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") parser.add_argument( @@ -239,6 +240,11 @@ def setup_parser(parser): if not args.use_hpu_graphs: args.limit_hpu_graphs = False + args.quant_config = os.getenv("QUANT_CONFIG", "") + if args.quant_config and args.kv_cache_fp8: + # can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization + # with habana quantization toolkit + parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument") return args @@ -539,6 +545,9 @@ def generate_dataset(batch): if prompt_length > 0: print(f"Graph compilation duration = {compilation_duration} seconds") print(separator) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(model) if __name__ == "__main__": diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index f16c126af2..e59393d71b 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -78,6 +78,7 @@ def __init__(self, tokenizer, model, args, options): if self.model.config.model_type == "llama": self.model_inputs.update( { + "reuse_cache": self.options.reuse_cache, "attn_softmax_bf16": self.options.attn_softmax_bf16, } ) @@ -125,10 +126,17 @@ def find_bucket(self, length): return [b for b in self.buckets if b >= length][0] def _model_call(self, inps): - seq_length = inps.shape[-1] + bs, seq_length = inps.shape padding_length = 0 if self.options.static_shapes: bucket_length = self.find_bucket(seq_length) + if self.options.use_cache and self.options.reuse_cache: + self.model.allocate_kv_cache( + bs, + bucket_length + 1, + bucket_length, + False, + ) padding_length = bucket_length - seq_length inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) logits = self.model(inps.to(self._device), 
**self.model_inputs)["logits"].cpu() @@ -164,6 +172,9 @@ def main(): print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) json.dump(results, open(args.output_file, "w"), indent=2) print(json.dumps(results, indent=2)) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(model) if __name__ == "__main__": diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 5c03de7dc6..f7e94d82e5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -96,7 +96,7 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) -def setup_quantization(model): +def setup_quantization(args, model): import habana_frameworks.torch.core as htcore from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const from habana_frameworks.torch.hpu import hpu @@ -104,8 +104,8 @@ def setup_quantization(model): print("Initializing inference with quantization") _mark_params_as_const(model) _check_params_as_const(model) - - hpu.enable_quantization() + if not args.quant_config: + hpu.enable_quantization() htcore.hpu_initialize(model) return model @@ -114,6 +114,8 @@ def setup_env(args): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
check_min_version("4.34.0") check_optimum_habana_min_version("1.9.0.dev0") + # TODO: SW-167588 - workaround for memory issue in habana_quantization_toolkit (hqt) prep_model + os.environ.setdefault('EXPERIMENTAL_WEIGHT_SHARING', 'FALSE') if args.global_rank == 0: os.environ.setdefault("GRAPH_VISUALIZATION", "true") @@ -158,6 +160,9 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = peft_model(args, model_dtype, logger, **model_kwargs) else: model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) model = model.eval().to(args.device) if args.use_hpu_graphs: @@ -178,7 +183,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): logger.info("DeepSpeed is enabled.") deepspeed.init_distributed(dist_backend="hccl") - config = AutoConfig.from_pretrained(args.model_name_or_path, **model_kwargs) + config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) load_to_meta = model_on_meta(config) if load_to_meta: @@ -227,6 +232,10 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = model.module if model.config.model_type == "llama": patch_scoped_linear_all_reduce(model) + + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) return model @@ -359,7 +368,7 @@ def initialize_model(args, logger): tokenizer, model = setup_tokenizer(args, model) generation_config = setup_generation_config(args, model, tokenizer) if args.fp8: - model = setup_quantization(model) + model = setup_quantization(args, model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 
d0e217c38d..136da16e84 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -39,7 +39,9 @@ def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur - cur = cur.to(dtype=prev.dtype) + if prev.dtype == torch.float8_e4m3fn: + from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 + cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: # Initialize prev[:, :, :inp_seq_len, :].copy_(cur) @@ -92,6 +94,32 @@ def __init__(self): def forward(self, x, y): return torch.matmul(x, y) +class KVCache(torch.nn.Module): + + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + if kv_cache_fp8: + dtype = torch.float8_e4m3fn + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. 
self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig): @@ -99,28 +127,19 @@ def __init__(self, config: LlamaConfig): self.matmul_qk = Matmul() self.matmul_av = Matmul() - self.past_key = None - self.past_value = None + self.k_cache = KVCache() + self.v_cache = KVCache() self.inp_seq_len = -1 + self.register_buffer("norm_factor", torch.tensor(1.0 / math.sqrt(self.head_dim)), persistent=False) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - key_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - value_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != key_shape: - self.inp_seq_len = inp_seq_len - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - if kv_cache_fp8: - dtype = torch.float8_e4m3fn - self.past_key = torch.zeros(key_shape, dtype=dtype, device=device) - self.past_value = torch.zeros(value_shape, dtype=dtype, device=device) - else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. 
self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - self.past_key.fill_(0) - self.past_value.fill_(0) + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -135,14 +154,14 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def pre_attn_forward( self, @@ -212,17 +231,15 @@ def pre_attn_forward( if past_key_value is not None or reuse_cache: # reuse k, v, self_attention if reuse_cache: - past_key = self.past_key - past_value = self.past_value + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx, self.inp_seq_len) - value_states = update(past_value, value_states, 2, token_idx, self.inp_seq_len) + key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states 
= update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if use_cache: if reuse_cache: - past_key_value = (self.past_key.shape, self.past_value.shape) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: past_key_value = (key_states.contiguous(), value_states.contiguous()) else: @@ -289,11 +306,11 @@ def pre_attn_forward( return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): - if self.o_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.o_proj, "all_reduce"): self.o_proj.all_reduce(attn_output) def post_attn_forward(self, attn_output): - if self.o_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.o_proj, "post_all_reduce"): self.o_proj.post_all_reduce(attn_output) return attn_output @@ -322,13 +339,13 @@ def pre_mlp_forward(self, x): return output def mlp_all_reduce(self, x): - if self.down_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.down_proj, "all_reduce"): self.down_proj.all_reduce(x) def post_mlp_forward(self, x): if self.config.pretraining_tp > 1: return x - if self.down_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.down_proj, "post_all_reduce"): return self.down_proj.post_all_reduce(x) return x