From c27663e05d5de6dad899c07daa0c591611d20b0d Mon Sep 17 00:00:00 2001
From: wufeisheng
Date: Thu, 14 Dec 2023 11:23:20 +0800
Subject: [PATCH] for develop

---
 llm/benchmark.sh  |   2 +
 llm/predictor.py  | 239 ++++++++++++++++++++++++++++++----------------
 llm/run_static.sh |   1 +
 3 files changed, 161 insertions(+), 81 deletions(-)

diff --git a/llm/benchmark.sh b/llm/benchmark.sh
index 0b3d4bcd3a28..ebfc94ff20b9 100644
--- a/llm/benchmark.sh
+++ b/llm/benchmark.sh
@@ -22,6 +22,8 @@ export FLAGS_fraction_of_gpu_memory_to_use=0.92
 export FLAGS_use_autotune=1
 export FLAGS_cublaslt_exhaustive_search_times=10
 export FLAGS_cache_inference_while_scope=1
+export FLAGS_dynamic_static_unified_comm=0
+# export FLAGS_benchmark=1
 
 model_dir=${1:-"checkpoints/llama65b_ptq_smooth_mp8"}
 src_len=${2:-1100}
diff --git a/llm/predictor.py b/llm/predictor.py
index f2551710790b..c669f62bb9a7 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -106,8 +106,11 @@ class PredictorArgument:
     block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."})
     use_cachekv_int8: str = field(
         default="None",
-        metadata={"help": "If use_cachekv_int8 set as `dynamic`, dynamic cache kv quantization will be applied; if set as `static`, static cache kv will be applied"},)
-
+        metadata={
+            "help": "If use_cachekv_int8 set as `dynamic`, dynamic cache kv quantization will be applied; if set as `static`, static cache kv will be applied"
+        },
+    )
+
     chat_template: str = field(
         default=None,
         metadata={
@@ -745,43 +748,67 @@ def __init__(
             self.cache_kvs[0].shape[-3],
             self.cache_kvs[0].shape[-1],
         )
-
+
         self.inputs = {}
         self.pre_cache_length = 0
         if config.export_precache:
             pre_cache_npy = np.load(config.prefix_path)
             self.pre_cache_length = pre_cache_npy.shape[-2]
             config.max_length -= self.pre_cache_length
-            self.pre_caches = [paddle.zeros([config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim], dtype=self.dtype) for _ in range(2 * self.num_layers)]
+            self.pre_caches = [
+                paddle.zeros(
+                    [config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim],
+                    dtype=self.dtype,
+                )
+                for _ in range(2 * self.num_layers)
+            ]
             print("pre_cache_length: ", self.pre_cache_length)
             for i in range(self.num_layers):
-                self.pre_caches[2 * i][:, :, :, :] = paddle.to_tensor(pre_cache_npy[i][0], dtype=self.dtype).unsqueeze(0)
-                self.pre_caches[2 * i + 1][:, :, :, :] = paddle.to_tensor(pre_cache_npy[i][1], dtype=self.dtype).unsqueeze(0)
+                self.pre_caches[2 * i][:, :, :, :] = paddle.to_tensor(pre_cache_npy[i][0], dtype=self.dtype).unsqueeze(
+                    0
+                )
+                self.pre_caches[2 * i + 1][:, :, :, :] = paddle.to_tensor(
+                    pre_cache_npy[i][1], dtype=self.dtype
+                ).unsqueeze(0)
             self.inputs["pre_caches"] = self.pre_caches
-            pre_cache_mask = paddle.zeros(shape=[config.batch_size, 1, config.src_length, config.src_length + self.pre_cache_length], dtype=self.dtype)
-            pre_cache_mask[:, :, :, :self.pre_cache_length] = 1
-            pre_cache_mask[:, :, :, self.pre_cache_length:] = paddle.tril(paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=self.dtype))
+            pre_cache_mask = paddle.zeros(
+                shape=[config.batch_size, 1, config.src_length, config.src_length + self.pre_cache_length],
+                dtype=self.dtype,
+            )
+            pre_cache_mask[:, :, :, : self.pre_cache_length] = 1
+            pre_cache_mask[:, :, :, self.pre_cache_length :] = paddle.tril(
+                paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=self.dtype)
+            )
             self.inputs["src_mask"] = (pre_cache_mask - 1) * 1e4
 
         if config.use_cachekv_int8 == "dynamic":
             self.k_quant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
             self.v_quant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
             self.k_dequant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
            ]
             self.v_dequant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
 
         # not update
         self.inputs["cache_kvs"] = self.cache_kvs
         self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64")
-        self.inputs["bad_tokens"] = paddle.to_tensor([-1, ], dtype="int64")
+        self.inputs["bad_tokens"] = paddle.to_tensor(
+            [
+                -1,
+            ],
+            dtype="int64",
+        )
         self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32")
         self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")
         self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")
@@ -791,7 +818,7 @@ def __init__(
             self.inputs["v_quant_scales"] = self.v_quant_scales
             self.inputs["k_dequant_scales"] = self.k_dequant_scales
             self.inputs["v_dequant_scales"] = self.v_dequant_scales
-
+
         if config.benchmark:
             min_length = config.max_length
         else:
@@ -824,19 +851,19 @@ def __init__(
         self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64")
         self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu()
         self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool")
-        self.inputs['next_tokens'] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
-        self.inputs['is_block_step'] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
-        free_list = list(range(pre_max_block_num - 1, int(pre_max_block_num * 0.75) -1, -1))
-        self.inputs['encoder_block_lens'] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
-        self.inputs['step_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['recover_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['recover_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['need_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['need_block_len'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['used_list_len'] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
-        self.inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
-        self.inputs['free_list_len'] = paddle.full(shape=[1], fill_value=pre_max_block_num * 0.25, dtype="int32")
+        self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
+        self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
+        free_list = list(range(pre_max_block_num - 1, int(pre_max_block_num * 0.75) - 1, -1))
+        self.inputs["encoder_block_lens"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
+        self.inputs["step_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["step_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["recover_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["recover_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["need_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["need_block_len"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["used_list_len"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
+        self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
+        self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=pre_max_block_num * 0.25, dtype="int32")
 
         self.free_list = [i for i in range(self.max_block_nums)][::-1]
         self.used_list = [[] for _ in range(config.batch_size)]
@@ -897,15 +924,17 @@ def _preprocess(self, source):
             self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty
             self.inputs["frequency_score"][i : i + 1] = 0.0
             self.inputs["presence_score"][i : i + 1] = 0.0
-            self.inputs['top_p'][i : i + 1] = self.config.top_p
-            self.inputs['temperature'][i : i + 1] = self.config.temperature
+            self.inputs["top_p"][i : i + 1] = self.config.top_p
+            self.inputs["temperature"][i : i + 1] = self.config.temperature
             self.inputs["seq_lens_this_time"][i : i + 1] = length
             self.inputs["seq_lens_encoder"][i : i + 1] = length
             self.inputs["seq_lens_decoder"][i : i + 1] = 0
             self.inputs["step_idx"][i : i + 1] = 0
             self.inputs["stop_flags"][i : i + 1] = False
             reset_stop_value(self.inputs["not_need_stop"])
-            need_block_nums = (length + self.config.max_length + self.pre_cache_length + self.block_size - 1) // self.block_size
+            need_block_nums = (
+                length + self.config.max_length + self.pre_cache_length + self.block_size - 1
+            ) // self.block_size
             for bi in range(need_block_nums):
                 bi_now = self.free_list.pop()
                 self.used_list[i].append(bi_now)
@@ -937,11 +966,24 @@ def __init__(
             self.pre_cache_length = pre_cache_npy.shape[-2]
             config.max_length -= self.pre_cache_length
             for i in range(self.num_layers):
-                self.inputs["pre_caches_{}".format(2 * i)] = paddle.to_tensor(pre_cache_npy[i][0], dtype=config.dtype).unsqueeze(0).broadcast_to([config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim])
-                self.inputs["pre_caches_{}".format(2 * i + 1)] = paddle.to_tensor(pre_cache_npy[i][1], dtype=config.dtype).unsqueeze(0).broadcast_to([config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim])
-            pre_cache_mask = paddle.zeros(shape=[config.batch_size, 1, config.src_length, config.src_length + self.pre_cache_length], dtype=config.dtype)
-            pre_cache_mask[:, :, :, :self.pre_cache_length] = 1
-            pre_cache_mask[:, :, :, self.pre_cache_length:] = paddle.tril(paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=config.dtype))
+                self.inputs["pre_caches_{}".format(2 * i)] = (
+                    paddle.to_tensor(pre_cache_npy[i][0], dtype=config.dtype)
+                    .unsqueeze(0)
+                    .broadcast_to([config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim])
+                )
+                self.inputs["pre_caches_{}".format(2 * i + 1)] = (
+                    paddle.to_tensor(pre_cache_npy[i][1], dtype=config.dtype)
+                    .unsqueeze(0)
+                    .broadcast_to([config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim])
+                )
+            pre_cache_mask = paddle.zeros(
+                shape=[config.batch_size, 1, config.src_length, config.src_length + self.pre_cache_length],
+                dtype=config.dtype,
+            )
+            pre_cache_mask[:, :, :, : self.pre_cache_length] = 1
+            pre_cache_mask[:, :, :, self.pre_cache_length :] = paddle.tril(
+                paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=config.dtype)
+            )
             self.inputs["src_mask"] = (pre_cache_mask - 1) * 1e4
 
         self.cache_kvs = {}
@@ -962,16 +1004,20 @@ def __init__(
 
         if config.use_cachekv_int8 == "dynamic":
             self.k_quant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
             self.v_quant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
             self.k_dequant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
             self.v_dequant_scales = [
-                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") for _ in range(self.num_layers)
+                paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32")
+                for _ in range(self.num_layers)
             ]
 
         if config.benchmark:
@@ -984,7 +1030,12 @@ def __init__(
         )
 
         self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64")
-        self.inputs["bad_tokens"] = paddle.to_tensor([-1, ], dtype="int64")
+        self.inputs["bad_tokens"] = paddle.to_tensor(
+            [
+                -1,
+            ],
+            dtype="int64",
+        )
         self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32")
         self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")
         self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32")
@@ -1014,22 +1065,20 @@ def __init__(
 
         self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu()
         self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool")
-        self.inputs['step_seq_lens_encoder'] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
-        self.inputs['next_tokens'] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
-        self.inputs['is_block_step'] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
-        free_list = list(range(pre_max_block_num - 1, int(pre_max_block_num * 0.75) -1, -1))
-        self.inputs['encoder_block_lens'] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
-        self.inputs['step_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['recover_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['recover_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['need_block_list'] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
-        self.inputs['need_block_len'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
-        self.inputs['used_list_len'] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
-        self.inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
-        self.inputs['free_list_len'] = paddle.full(shape=[1], fill_value=pre_max_block_num * 0.25, dtype="int32")
-
-
+        self.inputs["step_seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32")
+        self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64")
+        self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
+        free_list = list(range(pre_max_block_num - 1, int(pre_max_block_num * 0.75) - 1, -1))
+        self.inputs["encoder_block_lens"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
+        self.inputs["step_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["step_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["recover_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["recover_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["need_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32")
+        self.inputs["need_block_len"] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+        self.inputs["used_list_len"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32")
+        self.inputs["free_list"] = paddle.to_tensor(free_list, dtype="int32")
+        self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=pre_max_block_num * 0.25, dtype="int32")
 
         for i in range(self.num_layers):
             if self.config.use_cachekv_int8 == "dynamic":
@@ -1124,15 +1173,21 @@ def _infer(self):
         self.predictor.run()
 
     def predict(self, input_texts: str | list[str]):
+
+        s_time = time.time()
         self._preprocess(input_texts)
         real_bsz = len(input_texts)
 
         import copy
+
         seq_lens_this_time = copy.deepcopy(self.inputs["seq_lens_this_time"][:real_bsz])
         self.seq_lens_handle.share_external_data(seq_lens_this_time)
-
+        logger.info(f"preprocess spend {time.time() - s_time}")
+
+        s_time = time.time()
         while self.inputs["not_need_stop"]:
             self.predictor.run()
+        logger.info(f"running spend {time.time() - s_time}")
 
         # reset free_list
         for i in range(self.config.batch_size):
@@ -1144,7 +1199,9 @@ def predict(self, input_texts: str | list[str]):
     def _preprocess(self, source):
         for i, text in enumerate(source):
             # print("text: ", text)
-            tokens = self.tokenizer(text, return_tensors="np", padding=False, max_length=(self.config.src_length - self.config.max_length))
+            tokens = self.tokenizer(
+                text, return_tensors="np", padding=False, max_length=(self.config.src_length - self.config.max_length)
+            )
             input_ids = tokens["input_ids"][0]
             length = len(input_ids)
             # print("input_ids: ", input_ids)
@@ -1153,16 +1210,18 @@ def _preprocess(self, source):
             self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty
             self.inputs["frequency_score"][i : i + 1] = 0.0
             self.inputs["presence_score"][i : i + 1] = 0.0
-            self.inputs['top_p'][i:i+1] = self.config.top_p
-            self.inputs['temperature'][i:i+1] = self.config.temperature
+            self.inputs["top_p"][i : i + 1] = self.config.top_p
+            self.inputs["temperature"][i : i + 1] = self.config.temperature
self.inputs["seq_lens_this_time"][i : i + 1] = length - self.inputs['step_seq_lens_encoder'][i:i+1] = length + self.inputs["step_seq_lens_encoder"][i : i + 1] = length self.inputs["seq_lens_encoder"][i : i + 1] = length self.inputs["seq_lens_decoder"][i : i + 1] = 0 self.inputs["step_idx"][i : i + 1] = 0 self.inputs["stop_flags"][i : i + 1] = False reset_stop_value(self.inputs["not_need_stop"]) - need_block_nums = (length + self.config.max_length + self.pre_cache_length + self.block_size - 1) // self.block_size + need_block_nums = ( + length + self.config.max_length + self.pre_cache_length + self.block_size - 1 + ) // self.block_size # print("self.free_list", self.free_list) for bi in range(need_block_nums): bi_now = self.free_list.pop() @@ -1170,16 +1229,18 @@ def _preprocess(self, source): self.inputs["block_tables"][i : i + 1, bi] = bi_now # encoder_block_num = len(task['block_tables']) - self.inputs['encoder_block_lens'][i:i+1] = need_block_nums + self.inputs["encoder_block_lens"][i : i + 1] = need_block_nums + def get_ptq_multicards_num(directory): - count = 0 + count = 0 prefix = "act_scales_" - for filename in os.listdir(directory): - if filename.startswith(prefix): - count += 1 + for filename in os.listdir(directory): + if filename.startswith(prefix): + count += 1 return count + def create_predictor( predictor_args: PredictorArgument, model_args: ModelArgument, @@ -1248,7 +1309,7 @@ def create_predictor( config.model_name_or_path = "" config.use_cachekv_int8 = predictor_args.use_cachekv_int8 config.single_card_ptq = True - + if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"): weight_only_quant_bits = int(predictor_args.quant_type[-1]) config.weight_only_quant_bits = weight_only_quant_bits @@ -1257,7 +1318,7 @@ def create_predictor( if config.quantization_config.quant_type is not None and "a8w8" in config.quantization_config.quant_type: config.model_name_or_path = predictor_args.model_name_or_path config.quant_type = config.quantization_config.quant_type - + ptq_multicards_num = get_ptq_multicards_num(config.model_name_or_path) logger.info(f"PTQ from {ptq_multicards_num} cards, so we will not split") if ptq_multicards_num > 1: @@ -1284,7 +1345,9 @@ def create_predictor( LlamaForCausalLMInferenceModel as LlamaInferenceModel, ) model = LlamaInferenceModel.from_pretrained( - predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype, + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, tensor_parallel_degree=tensor_parallel_degree, tensor_parallel_rank=tensor_parallel_rank, ) @@ -1358,7 +1421,6 @@ def create_predictor( else: predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) - elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) if "llama" in config.architectures[0].lower(): @@ -1451,8 +1513,8 @@ def predict(): # source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"] source_texts = [] - data_file = open("humaneval_solution.json", 'r') - + data_file = open("humaneval_solution.json", "r") + dataset = [] for line in data_file.readlines(): dataset.append(json.loads(line)) @@ -1461,7 +1523,6 @@ def predict(): data = dataset[i % 164] source_texts.append(data["prompt"]) - target_texts = [""] * predictor_args.batch_size batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) @@ -1489,8 +1550,11 @@ def predict(): def benchmark(predictor, predictor_args, model_args): # Just construct a simple 
     # Just construct a simple benchmark input. We pad input to the src_length.
-    test_texts = ""
-    benchmark_texts = [test_texts + "" * (predictor_args.src_length - predictor_args.max_length) for _ in range(predictor_args.batch_size)]
+    test_texts = "who are you"
+    benchmark_texts = [
+        test_texts + "" * (predictor_args.src_length - predictor_args.max_length)
+        for _ in range(predictor_args.batch_size)
+    ]
     batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
 
     print("***********Start Benchmark**********")
@@ -1501,17 +1565,30 @@ def benchmark(predictor, predictor_args, model_args):
     print("***********Start Warmup**********")
     for i in range(warmup_time):
         print("warm up ", i)
-        for bs, batch_source_text in enumerate(batch_benchmark_texts):
-            outputs = predictor.predict(batch_source_text)
+        for _, batch_source_text in enumerate(batch_benchmark_texts):
+            predictor.predict(batch_source_text)
+
+    from paddle import profiler
+
+    # Profiler setup.
+    def my_on_trace_ready(prof):  # callback invoked when the profiler finishes collecting data
+        callback = profiler.export_chrome_tracing("./profiler_demo")  # export the trace to the profiler_demo directory
+        callback(prof)  # run the export callback
+        prof.summary(sorted_by=profiler.SortedKeys.GPUTotal)  # print the summary table, sorted by total GPU time
+
+    p = profiler.Profiler(scheduler=[3, 4], on_trace_ready=my_on_trace_ready, timer_only=False)  # initialize the Profiler object
 
     print("***********Start Speed Test**********")
     start = time.perf_counter()
     output_tokens = 0
+    p.start()
     for i in range(test_time):
         print("test ", i)
-        for bs, batch_source_text in enumerate(batch_benchmark_texts):
-            outputs = predictor.predict(batch_source_text)
+        for _, batch_source_text in enumerate(batch_benchmark_texts):
+            predictor.predict(batch_source_text)
         output_tokens += predictor_args.max_length * predictor_args.batch_size
+        p.step()
+    p.stop()
     end = time.perf_counter()
     print("Avg Elapse time is: ", (end - start) / test_time)
     print("Output tokens is: ", output_tokens)
diff --git a/llm/run_static.sh b/llm/run_static.sh
index fc05c5bc1775..a9349b1288d0 100644
--- a/llm/run_static.sh
+++ b/llm/run_static.sh
@@ -21,6 +21,7 @@ export FLAGS_control_flow_use_new_executor=1
 export FLAGS_new_executor_serial_run=1
 export FLAGS_allocator_strategy=naive_best_fit
 export FLAGS_fraction_of_gpu_memory_to_use=0.92
+export FLAGS_dynamic_static_unified_comm=0
 
 model_dir=${1:-"checkpoints/llama65b_ptq_smooth_mp8"}