From 3bae58d650429fa8495dc6cf6d7ea24f944ac967 Mon Sep 17 00:00:00 2001 From: Yan Tomsinsky <73292515+Yantom1@users.noreply.github.com> Date: Tue, 20 Feb 2024 14:49:26 +0200 Subject: [PATCH 1/6] Enable Llama2 70B to run with hqt on single card (#50) Add disk_offload flag that controls device_map=auto. Setting this flag enbales weights offload to disk when cpu memory runs OOM. Add const serialization path flag that gets a path for where to serialize const sections, so if there is no space on device to save all const sections they will be offloaded to disk. --- examples/text-generation/run_generation.py | 14 +++++++++++++- examples/text-generation/run_lm_eval.py | 3 +++ examples/text-generation/utils.py | 15 ++++++++++++--- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 6b0b2e4695..b4d8654c07 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -239,7 +239,16 @@ def setup_parser(parser): ) parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") - + parser.add_argument( + '--const_serialization_path', + '--csp', + type=str, + help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.") + parser.add_argument( + "--disk_offload", + action="store_true", + help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", + ) args = parser.parse_args() if args.torch_compile: @@ -561,6 +570,9 @@ def generate_dataset(batch): import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(model) + if args.const_serialization_path and os.path.isdir(args.const_serialization_path): + import shutil + shutil.rmtree(args.const_serialization_path) if __name__ == "__main__": diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4ae8dcb26c..5fee4a4af9 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -176,6 +176,9 @@ def main(): import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(model) + if args.const_serialization_path and os.path.isdir(args.const_serialization_path): + import shutil + shutil.rmtree(args.const_serialization_path) if __name__ == "__main__": diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index e8c847c2f7..4ead35b026 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -98,12 +98,9 @@ def setup_distributed(args): def setup_quantization(args, model): import habana_frameworks.torch.core as htcore - from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const from habana_frameworks.torch.hpu import hpu print("Initializing inference with quantization") - _mark_params_as_const(model) - _check_params_as_const(model) if not args.quant_config: hpu.enable_quantization() htcore.hpu_initialize(model) @@ -373,6 +370,10 @@ def initialize_model(args, logger): "revision": args.model_revision, "token": args.token, } + if args.disk_offload: + model_kwargs["device_map"] = "auto" + model_kwargs["offload_folder"] = "/tmp/offload_folder/" + model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed @@ -380,6 +381,14 @@ def initialize_model(args, logger): ) tokenizer, model = setup_tokenizer(args, model) generation_config = setup_generation_config(args, model, tokenizer) + + if args.const_serialization_path: + import uuid + args.const_serialization_path = os.path.join(args.const_serialization_path + uuid.uuid4().hex) + os.makedirs(args.const_serialization_path) + from habana_frameworks.torch.hpu import enable_const_section_serialization + print("Serializing const params to {}".format(args.const_serialization_path)) + enable_const_section_serialization(args.const_serialization_path, False, True) if args.fp8: model = setup_quantization(args, model) init_end = time.perf_counter() From 5a7aa50d739b8568c5440b96fc235c310e54ce1a Mon Sep 17 00:00:00 2001 From: bgoldberg-habana <149692267+bgoldberg-habana@users.noreply.github.com> Date: Mon, 26 Feb 2024 15:12:50 +0200 Subject: [PATCH 2/6] llama fp8 - enable non reuse cache flow for fp8 (#64) * llama fp8 - enable non reuse cache flow for fp8 remove depracted kv cache fp8 flag Change-Id: Id76f94a127dee202376e8f27de7b28f58affedae * fixing lm eval Change-Id: I230fa53e7b49d8bb36397b063f652ba3def84600 * remove old quantization mode Change-Id: I538172f29870311349ed79d928cfacc60fb534e8 --- examples/text-generation/README.md | 1 - examples/text-generation/run_generation.py | 9 --- examples/text-generation/run_lm_eval.py | 3 +- examples/text-generation/utils.py | 24 +++--- .../generation/configuration_utils.py | 5 -- .../habana/transformers/generation/utils.py | 3 +- .../models/llama/modeling_llama.py | 75 +++++++++---------- 7 files changed, 48 insertions(+), 72 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 83a481970c..35cdd0da8c 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -108,7 +108,6 @@ Here are a few settings you may be interested in: - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it - `--fp8` Enable Quantization to fp8 -- `--kv_cache_fp8` Deprecated - Store kv-cache in float8 when kv-cache is used. should not be used with HQT(The Quantization Toolkit) For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: ```bash diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index b4d8654c07..98577ebdae 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -221,11 +221,6 @@ def setup_parser(parser): help="Preprocess on cpu, and some other optimizations. Useful to prevent recompilations when using dynamic prompts (simulate_dyn_prompt)", ) - parser.add_argument( - "--kv_cache_fp8", - action="store_true", - help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var", - ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") parser.add_argument( "--use_flash_attention", @@ -258,10 +253,6 @@ def setup_parser(parser): args.limit_hpu_graphs = False args.quant_config = os.getenv("QUANT_CONFIG", "") - if args.quant_config and args.kv_cache_fp8: - # can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization - # with habana quantization toolkit - raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument") return args diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 5fee4a4af9..4f90306354 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -134,8 +134,7 @@ def _model_call(self, inps): self.model.allocate_kv_cache( bs, bucket_length + 1, - bucket_length, - False, + bucket_length ) padding_length = bucket_length - seq_length inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 4ead35b026..c1d53d3736 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -96,16 +96,20 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) -def setup_quantization(args, model): +def setup_inference(args, model): import habana_frameworks.torch.core as htcore - from habana_frameworks.torch.hpu import hpu - print("Initializing inference with quantization") - if not args.quant_config: - hpu.enable_quantization() + print("Initializing inference mode") htcore.hpu_initialize(model) return model +def setup_const_serialization(const_serialization_path): + import uuid + const_serialization_path = os.path.join(const_serialization_path + uuid.uuid4().hex) + os.makedirs(const_serialization_path) + from habana_frameworks.torch.hpu import enable_const_section_serialization + print("Serializing const params to {}".format(const_serialization_path)) + enable_const_section_serialization(const_serialization_path, False, True) def setup_env(args): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. @@ -346,7 +350,6 @@ def setup_generation_config(args, model, tokenizer): generation_config.reduce_recompile = args.reduce_recompile if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 - generation_config.kv_cache_fp8 = args.kv_cache_fp8 generation_config.use_flash_attention = args.use_flash_attention return generation_config @@ -383,14 +386,9 @@ def initialize_model(args, logger): generation_config = setup_generation_config(args, model, tokenizer) if args.const_serialization_path: - import uuid - args.const_serialization_path = os.path.join(args.const_serialization_path + uuid.uuid4().hex) - os.makedirs(args.const_serialization_path) - from habana_frameworks.torch.hpu import enable_const_section_serialization - print("Serializing const params to {}".format(args.const_serialization_path)) - enable_const_section_serialization(args.const_serialization_path, False, True) + setup_const_serialization(args.const_serialization_path) if args.fp8: - model = setup_quantization(args, model) + model = setup_inference(args, model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index de537dbe4f..facf4c6f36 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -27,10 +27,6 @@ class GaudiGenerationConfig(GenerationConfig): If negative (default=-1) pad to max if `static_shapes` is set. Else start with `shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow space by `bucket_size` when needed. Only active if `static_shapes` is used. Can't be used with `reuse_cache`. - bucket_internal (`bool`, *optional*): - Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large. - kv_cache_fp8 (`bool`, *optional*): - Store kv-cache in float8 when kv-cache is used use_flash_attention (`bool`, *optional*): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): @@ -48,6 +44,5 @@ def __init__(self, **kwargs): self.bucket_size = kwargs.get("bucket_size", -1) self.bucket_internal = kwargs.get("bucket_internal", None) self.reduce_recompile = kwargs.get("reduce_recompile", None) - self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 42da95552e..eb3b78f9fe 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -694,8 +694,7 @@ def generate( unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, - token_idx, - generation_config.kv_cache_fp8, + token_idx ) model_kwargs["kv_cache_len"] = calculated_max_length diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index b6411e6b54..d7d9dda57e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -43,23 +43,6 @@ FusedSDPA = None -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) def gaudi_llama_rmsnorm_forward(self, hidden_states): @@ -133,11 +116,9 @@ def __init__(self): self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -145,13 +126,29 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiLlamaAttention(LlamaAttention): @@ -165,12 +162,12 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.k_proj.weight.device dtype = self.config.torch_dtype - self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) - self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -276,26 +273,25 @@ def pre_attn_forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: + if use_cache: # reuse k, v, self_attention if reuse_cache: key_states = self.k_cache(key_states, 2, token_idx) value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if cache_idx is not None and q_len == 1: key_states = key_states[:, :, :cache_idx, :] value_states = value_states[:, :, :cache_idx, :] attention_mask = attention_mask[:, :, :, :cache_idx] kv_seq_len = key_states.shape[-2] - - if use_cache: - if reuse_cache: - past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) else: past_key_value = None @@ -433,8 +429,8 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -552,9 +548,9 @@ def post_mlp(self, input, residual): class GaudiLlamaModel(LlamaModel): - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -742,9 +738,8 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) - self.kv_cache_len = max_seq_len + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) From 157634d9c7f8ba48cb3787c5e4476ed84616eeb7 Mon Sep 17 00:00:00 2001 From: Danny Semiat Date: Thu, 7 Mar 2024 15:32:34 +0200 Subject: [PATCH 3/6] Remove setup_inference() in utils.py --- examples/text-generation/utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index c1d53d3736..7242b3e670 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -96,13 +96,6 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) -def setup_inference(args, model): - import habana_frameworks.torch.core as htcore - - print("Initializing inference mode") - htcore.hpu_initialize(model) - return model - def setup_const_serialization(const_serialization_path): import uuid const_serialization_path = os.path.join(const_serialization_path + uuid.uuid4().hex) @@ -388,7 +381,9 @@ def initialize_model(args, logger): if args.const_serialization_path: setup_const_serialization(args.const_serialization_path) if args.fp8: - model = setup_inference(args, model) + import habana_frameworks.torch.core as htcore + print("Initializing inference mode") + htcore.hpu_initialize(model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") From c55cb614f13fa57103aa6cec3930bb385d8eb467 Mon Sep 17 00:00:00 2001 From: Danny Semiat Date: Mon, 11 Mar 2024 11:18:39 +0200 Subject: [PATCH 4/6] stylize code --- examples/text-generation/run_generation.py | 8 +++++--- examples/text-generation/run_lm_eval.py | 7 ++----- examples/text-generation/utils.py | 6 +++++- optimum/habana/transformers/generation/utils.py | 4 +--- .../habana/transformers/models/llama/modeling_llama.py | 6 +++--- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 98577ebdae..1f503ed5e1 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -235,10 +235,11 @@ def setup_parser(parser): parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") parser.add_argument( - '--const_serialization_path', - '--csp', + "--const_serialization_path", + "--csp", type=str, - help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.") + help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.", + ) parser.add_argument( "--disk_offload", action="store_true", @@ -563,6 +564,7 @@ def generate_dataset(batch): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil + shutil.rmtree(args.const_serialization_path) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4f90306354..3b108cf3f5 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -131,11 +131,7 @@ def _model_call(self, inps): if self.options.static_shapes: bucket_length = self.find_bucket(seq_length) if self.options.use_cache and self.options.reuse_cache: - self.model.allocate_kv_cache( - bs, - bucket_length + 1, - bucket_length - ) + self.model.allocate_kv_cache(bs, bucket_length + 1, bucket_length) padding_length = bucket_length - seq_length inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu() @@ -177,6 +173,7 @@ def main(): habana_quantization_toolkit.finish_measurements(model) if args.const_serialization_path and os.path.isdir(args.const_serialization_path): import shutil + shutil.rmtree(args.const_serialization_path) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 7242b3e670..018472c636 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -98,12 +98,15 @@ def setup_distributed(args): def setup_const_serialization(const_serialization_path): import uuid - const_serialization_path = os.path.join(const_serialization_path + uuid.uuid4().hex) + + const_serialization_path = os.path.join(const_serialization_path + uuid.uuid4().hex) os.makedirs(const_serialization_path) from habana_frameworks.torch.hpu import enable_const_section_serialization + print("Serializing const params to {}".format(const_serialization_path)) enable_const_section_serialization(const_serialization_path, False, True) + def setup_env(args): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.34.0") @@ -382,6 +385,7 @@ def initialize_model(args, logger): setup_const_serialization(args.const_serialization_path) if args.fp8: import habana_frameworks.torch.core as htcore + print("Initializing inference mode") htcore.hpu_initialize(model) init_end = time.perf_counter() diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index eb3b78f9fe..95ea13388c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -692,9 +692,7 @@ def generate( bs, _ = input_ids.shape if not is_greedy_or_beam_and_bucket: unwrap_deepspeed_model(self).allocate_kv_cache( - bs * generation_config.num_beams, - calculated_max_length, - token_idx + bs * generation_config.num_beams, calculated_max_length, token_idx ) model_kwargs["kv_cache_len"] = calculated_max_length diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d7d9dda57e..2194280fb0 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -43,8 +43,6 @@ FusedSDPA = None - - def gaudi_llama_rmsnorm_forward(self, hidden_states): """ Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -282,7 +280,9 @@ def pre_attn_forward( else: if past_key_value is None: past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) - past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) past_key_value = (past_key, past_value) key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) From 36c529f763bdaaf5319581f83a1c83ac0a2758fa Mon Sep 17 00:00:00 2001 From: Danny Semiat Date: Sun, 17 Mar 2024 09:16:55 +0200 Subject: [PATCH 5/6] Re-add bucket_internal, mistakenly removed from cherry-pick Update configuration_utils.py --- optimum/habana/transformers/generation/configuration_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index facf4c6f36..ca5523d68b 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -27,6 +27,8 @@ class GaudiGenerationConfig(GenerationConfig): If negative (default=-1) pad to max if `static_shapes` is set. Else start with `shape = bucket_size * ceil(prompt_len/bucket_size)` and then grow space by `bucket_size` when needed. Only active if `static_shapes` is used. Can't be used with `reuse_cache`. + bucket_internal (`bool`, *optional*): + Split kv sequence into buckets in decode phase. It improves throughput when max_new_tokens is large. use_flash_attention (`bool`, *optional*): Whether to use flash attention optimization. flash_attention_recompute (`bool`, *optional*): From b6cf31d83c35636f82d204834fb2562f80314312 Mon Sep 17 00:00:00 2001 From: Danny Semiat Date: Tue, 19 Mar 2024 10:00:08 +0200 Subject: [PATCH 6/6] add ENABLE_CONST_MARKING flag in OH --- examples/text-generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 018472c636..32831cf73e 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -387,7 +387,9 @@ def initialize_model(args, logger): import habana_frameworks.torch.core as htcore print("Initializing inference mode") - htcore.hpu_initialize(model) + const_marking = os.getenv("ENABLE_CONST_MARKING", "True") + if const_marking == "True": + htcore.hpu_initialize(model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")