From 38cb2c15ab0d86819dd1e935c31b3e65c1b52e6f Mon Sep 17 00:00:00 2001 From: Adam Stachowicz Date: Tue, 2 Sep 2025 20:57:47 +0300 Subject: [PATCH] Add FP8 postprocess_measure.py --- examples/text-generation/README.md | 60 +++++++- .../quantization_tools/postprocess_measure.py | 145 ++++++++++++++++++ 2 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 examples/text-generation/quantization_tools/postprocess_measure.py diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 73e9ad0e35..6fdfca1e5b 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -393,11 +393,62 @@ PT_ENABLE_INT64_SUPPORT=1 python ../gaudi_spawn.py --world_size 8 run_generatio ### Running with FP8 -Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-180B and Llama3-405B in FP8 are enabled using the [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. From synapse 1.17 / optimum-habana 1.13 release, INC is used by default for measuring and quantization. Habana Quantization Toolkit (HQT), which was used earlier, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`. +Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-180B and Llama3-405B in FP8 are enabled using the [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. +From synapse 1.17 / optimum-habana 1.13 release, INC is used by default for measuring and quantization. Habana Quantization Toolkit (HQT), which was used earlier, will be removed in future releases. To use HQT, disable INC by setting the following environment variable: `USE_INC=0`. More information on enabling fp8 in SynapseAI is available here: https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html +#### Post-processing measurement artifacts before quantization + +After the measurement phase finishes (i.e. after running with a QUANT_CONFIG that contains `"mode": "measure"`), you must run the helper script `quantization_tools/postprocess_measure.py` to fix cache input mappings in the collected measurement JSON/NPZ files. The quantization phase (with a `"mode": "quantize"` QUANT_CONFIG) should then use the post‑processed outputs. + +Why needed: +- Ensures K/V cache tensors are properly associated for attention layers (standard, flash attention, DeepSeek MLA, etc.). +- Prevents incorrect scaling statistics that could degrade accuracy. + +Usage: +```bash +python quantization_tools/postprocess_measure.py \ + -m \ + -o \ + [--deepseek] # add only for DeepSeek / MLA style models +``` + +Example end‑to‑end (Mixtral): +1. Measurement: +```bash +PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json \ +python run_generation.py \ + --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \ + --use_hpu_graphs --use_kv_cache --limit_hpu_graphs \ + --bucket_size 128 --max_new_tokens 128 --batch_size 1 --bf16 +``` +2. Post-process: +```bash +python quantization_tools/postprocess_measure.py \ + -m ./measurements_mixtral_bs1 \ + -o ./measurements_mixtral_bs1_fixed +``` +(Replace the measurement directory names above with the one actually produced in your run.) +3. Quantize (now using the fixed measurements): +```bash +PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json \ +MEASUREMENTS_DIR=./measurements_mixtral_bs1_fixed \ +python run_generation.py \ + --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \ + --use_hpu_graphs --use_kv_cache --limit_hpu_graphs \ + --bucket_size 128 --max_new_tokens 2048 --batch_size 16 --bf16 +``` +(If your quant config requires an environment variable or a path field for measurements, point it to the fixed directory.) + +For DeepSeek models add --deepseek: +```bash +python quantization_tools/postprocess_measure.py -m -o --deepseek +``` + +#### Measuring Tensor Quantization Statistics Examples + Here is an example to measure the tensor quantization statistics on Mixtral-8x7B with 1 card: ```bash PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \ @@ -411,6 +462,8 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json python --bf16 ``` +(After this measurement run, execute the post-processing step before proceeding.) + Here is an example to quantize the model based on previous measurements for Mixtral-8x7B with 1 card: ```bash PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generation.py \ @@ -442,6 +495,7 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure_include_out --flash_attention_causal_mask \ --trust_remote_code ``` +> Run postprocess_measure.py on the produced measurement directory before executing the corresponding quantization command below. Here is an example to quantize the model based on previous measurements for Falcon-180B with 8 cards: ```bash @@ -480,6 +534,7 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure_include_out --flash_attention_causal_mask \ --trust_remote_code ``` +> Post-process the measurements prior to quantization. Here is an example to quantize the model based on previous measurements for Llama3-405B with 8 cards: > Please note that Llama3-405B requires minimum 16 cards Gaudi2 and 8 cards Gaudi3. @@ -502,7 +557,6 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_quant.json python . ``` Here is an example to measure the tensor quantization statistics on Llama3-8b with 1 card: - ```bash PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_lm_eval.py \ -o acc_Llama3-8b_bs1_measure.txt \ @@ -516,6 +570,7 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json python --bf16 \ --trust_remote_code ``` +> Run postprocess_measure.py before the quantization example that follows. Here is an example to quantize the model based on previous measurements for Llama3-8b with 1 card: ```bash @@ -543,6 +598,7 @@ PT_HPU_LAZY_MODE=1 QUANT_CONFIG=./quantization_config/maxabs_measure.json python --bf16 \ --sdp_on_bf16 ``` +> Post-process measurements before quantization. Here is an example to quantize the model based on previous measurements for gemma with 1 card: ```bash diff --git a/examples/text-generation/quantization_tools/postprocess_measure.py b/examples/text-generation/quantization_tools/postprocess_measure.py new file mode 100644 index 0000000000..3ebf1c9371 --- /dev/null +++ b/examples/text-generation/quantization_tools/postprocess_measure.py @@ -0,0 +1,145 @@ +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### +import argparse +import json +import os +import sys + +import numpy as np + + +def fix_cache_inputs(json_data, args): + for layer_index in range(len(json_data['Nodes'])): + matmul_av_input = None + v_cache_input = None + matmul_qk_input = None + k_cache_input = None + # Flash attention case + fsdpa_k_input = None + fsdpa_v_input = None + # OH + oh_k_cache_input = None + oh_v_cache_input = None + + attn_name = "attn" + k_cache_name = "k_cache" + v_cache_name = "v_cache" + if args.deepseek: + print(f"Handling deepseek model") + attn_name = "mla_attn" + k_cache_name = "latent_cache_k" + + for node_name, node_info in json_data['Nodes'].items(): + if f'model.layers.{layer_index}.self_attn.{attn_name}.impl.matmul_av' in node_name: + matmul_av_input = node_info['inputs'][1] + if f'model.layers.{layer_index}.self_attn.{attn_name}.impl.{v_cache_name}' in node_name: + v_cache_input = node_info['inputs'][0] + if f'model.layers.{layer_index}.self_attn.{attn_name}.impl.matmul_qk' in node_name: + matmul_qk_input = node_info['inputs'][1] + if f'model.layers.{layer_index}.self_attn.{attn_name}.impl.{k_cache_name}' in node_name: + k_cache_input = node_info['inputs'][0] + # Flash attention case + if f'model.layers.{layer_index}.self_attn.fused_scaled_dot_product_attention' in node_name: + fsdpa_k_input = node_info['inputs'][1] + fsdpa_v_input = node_info['inputs'][2] + # Optimum-habana case + if f'model.layers.{layer_index}.self_attn.{k_cache_name}' in node_name: + oh_k_cache_input = node_info['inputs'][0] + if f'model.layers.{layer_index}.self_attn.{v_cache_name}' in node_name: + oh_v_cache_input = node_info['inputs'][0] + + if matmul_av_input != v_cache_input: + if args.deepseek: + # For deepseek, there is one tensor for k_cache and v_cache + json_data['Nodes'][f'model.layers.{layer_index}.self_attn.{attn_name}.impl.matmul_av']['inputs'][1] = k_cache_input + else: + json_data['Nodes'][f'model.layers.{layer_index}.self_attn.attn.impl.matmul_av']['inputs'][1] = v_cache_input + if matmul_qk_input != k_cache_input: + json_data['Nodes'][f'model.layers.{layer_index}.self_attn.attn.impl.matmul_qk']['inputs'][1] = k_cache_input + + # Flash attention + if fsdpa_k_input != oh_k_cache_input: + json_data['Nodes'][f'model.layers.{layer_index}.self_attn.fused_scaled_dot_product_attention']['inputs'][1] = oh_k_cache_input + if fsdpa_v_input != oh_v_cache_input: + json_data['Nodes'][f'model.layers.{layer_index}.self_attn.fused_scaled_dot_product_attention']['inputs'][2] = oh_v_cache_input + + return json_data + + +def parse_args(args): + parser = argparse.ArgumentParser( + description="Run the measurements parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-m", "--measurements", type=str, help="full path to the directory of the measurements that should be fixed" + ) + parser.add_argument( + "-o", + "--out", + type=str, + default=os.getcwd(), + help="path to the directory where the fixed measurements will be written", + ) + parser.add_argument( + "-d", + "--deepseek", + action="store_true", + help="if handle deepseek models, please set this flag", + ) + return parser.parse_args(args) + + +def main(args): + args = parse_args(args) + output_path = args.out + if not os.path.exists(output_path): + os.makedirs(output_path) + measurements_path = args.measurements + measurements_paths = os.listdir(measurements_path) + measurements_paths_ranges = [measurement_path for measurement_path in measurements_paths if measurement_path.endswith( + ".json") and 'MAXABS_HW' not in measurement_path and "mod_list" not in measurement_path] + measurements_paths_scales = [measurement_path for measurement_path in measurements_paths if measurement_path.endswith( + ".json") and 'MAXABS_HW' in measurement_path and "mod_list" not in measurement_path] + print(measurements_paths_ranges) + print(measurements_paths_scales) + for measurement in measurements_paths_ranges + measurements_paths_scales: + fixed_json_path = os.path.join( + output_path, f"{measurement.split(os.sep)[-1]}") + with open(fixed_json_path, "w") as fixed_json_file: + with open(os.path.join(measurements_path, measurement), "r") as json_file: + data_to_fix = json.load(json_file) + fixed_data = fix_cache_inputs(data_to_fix, args) + json.dump(fixed_data, fixed_json_file) + print("") + print("measurement=", measurement, flush=True) + print("measurements_paths_scales=", + measurements_paths_scales, flush=True) + if measurement in measurements_paths_ranges + measurements_paths_scales: + global_rank = fixed_data["GlobalRank"] + local_rank = fixed_data["LocalRank"] + mode = fixed_data["Mode"] + nodes = fixed_data["Nodes"] + layers = {} + fixed_npz_path = fixed_json_path.replace(".json", ".npz") + for layer, dlayer in nodes.items(): + layers[layer] = {} + layers[layer]["inputs"] = [ + np.array(x) for x in dlayer["inputs"]] + if dlayer.get("outputs") is not None: + layers[layer]["outputs"] = [ + np.array(x) for x in dlayer["outputs"]] + if dlayer.get("params") is not None and dlayer["params"].get("weight") is not None: + layers[layer]["params"] = {} + layers[layer]["params"]["weight"] = np.array( + dlayer["params"]["weight"]) + df = {"GlobalRank": global_rank, + "LocalRank": local_rank, "Mode": mode, "Nodes": layers} + with open(fixed_npz_path, "w"): + np.savez(fixed_npz_path, df) + + print("finished fix_measurements script") + + +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file