diff --git a/benchmarks/inference/bert-bench.py b/benchmarks/inference/bert-bench.py
new file mode 100644
index 000000000000..e576d67f7d82
--- /dev/null
+++ b/benchmarks/inference/bert-bench.py
@@ -0,0 +1,88 @@
+import torch
+import time
+import deepspeed
+import argparse
+from transformers import pipeline
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", "-m", type=str, help="hf model name")
+parser.add_argument("--deepspeed", action="store_true", help="use deepspeed inference")
+parser.add_argument("--dtype", type=str, default="fp16", help="fp16 or fp32")
+parser.add_argument("--max-tokens", type=int, default=50, help="max new tokens")
+parser.add_argument("--local_rank", type=int, default=0, help="local rank")
+parser.add_argument("--trials", type=int, default=30, help="number of trials")
+parser.add_argument("--kernel-inject", action="store_true", help="inject kernels on")
+parser.add_argument("--graphs", action="store_true", help="CUDA Graphs on")
+args = parser.parse_args()
+
+
+def print_latency(latency_set, title, warmup=3):
+    # trim warmup queries
+    latency_set = latency_set[warmup:]
+    count = len(latency_set)
+    if count > 0:
+        latency_set.sort()
+        n50 = (count - 1) * 0.5 + 1
+        n90 = (count - 1) * 0.9 + 1
+        n95 = (count - 1) * 0.95 + 1
+        n99 = (count - 1) * 0.99 + 1
+        n999 = (count - 1) * 0.999 + 1
+
+        avg = sum(latency_set) / count
+        p50 = latency_set[int(n50) - 1]
+        p90 = latency_set[int(n90) - 1]
+        p95 = latency_set[int(n95) - 1]
+        p99 = latency_set[int(n99) - 1]
+        p999 = latency_set[int(n999) - 1]
+
+        print(f"====== latency stats {title} ======")
+        print("\tAvg Latency: {0:8.2f} ms".format(avg * 1000))
+        print("\tP50 Latency: {0:8.2f} ms".format(p50 * 1000))
+        print("\tP90 Latency: {0:8.2f} ms".format(p90 * 1000))
+        print("\tP95 Latency: {0:8.2f} ms".format(p95 * 1000))
+        print("\tP99 Latency: {0:8.2f} ms".format(p99 * 1000))
+        print("\tP999 Latency: {0:8.2f} ms".format(p999 * 1000))
+
+
+deepspeed.init_distributed("nccl")
+
+print(args.model, args.max_tokens, args.dtype)
+
+if args.dtype.lower() == "fp16":
+    dtype = torch.float16
+else:
+    dtype = torch.float32
+
+pipe = pipeline("fill-mask", model=args.model, framework="pt", device=args.local_rank)
+
+if dtype == torch.half:
+    pipe.model.half()
+
+br = pipe("Hello I'm a [MASK] model")
+if args.deepspeed:
+    pipe.model = deepspeed.init_inference(pipe.model,
+                                          dtype=dtype,
+                                          mp_size=1,
+                                          replace_with_kernel_inject=args.kernel_inject,
+                                          replace_method='auto',
+                                          enable_cuda_graph=args.graphs)
+    pipe.model.profile_model_time()
+
+responses = []
+times = []
+mtimes = []
+for i in range(args.trials):
+    torch.cuda.synchronize()
+    start = time.time()
+    r = pipe("Hello I'm a [MASK] model")
+    torch.cuda.synchronize()
+    end = time.time()
+    responses.append(r)
+    times.append((end - start))
+    mtimes += pipe.model.model_times()
+    #print(f"{pipe.model.model_times()=}")
+
+print_latency(times, "e2e latency")
+print_latency(mtimes, "model latency")
+
+print(responses[0:3])
diff --git a/benchmarks/inference/gpt-bench.py b/benchmarks/inference/gpt-bench.py
index af1370abad1f..9d3905946e1b 100644
--- a/benchmarks/inference/gpt-bench.py
+++ b/benchmarks/inference/gpt-bench.py
@@ -34,6 +34,7 @@
 
 def print_latency(latency_set, title, warmup=3):
     # trim warmup queries
+    latency_set = list(latency_set)
     latency_set = latency_set[warmup:]
     count = len(latency_set)
     if count > 0:
@@ -94,9 +95,11 @@ def print_latency(latency_set, title, warmup=3):
         replace_method="auto",
         enable_cuda_graph=args.graphs,
     )
+    pipe.model.profile_model_time()
 
 responses = []
 times = []
+mtimes = []
 for i in range(args.trials):
     torch.cuda.synchronize()
     start = time.time()
@@ -104,10 +107,15 @@ def print_latency(latency_set, title, warmup=3):
     torch.cuda.synchronize()
     end = time.time()
     responses.append(r)
-    times.append((end - start) / (args.max_tokens - 3))
+    times.append(end - start)  # / (args.max_tokens - 3))
+    mtimes.append(sum(pipe.model.model_times()))
 
 if args.local_rank == 0:
-    print_latency(times, "token latency")
+    print_latency(times, "(e2e) latency")
+    print_latency(mtimes, "(model-only) latency")
+    print_latency(map(lambda t: t / (args.max_tokens - 3),
+                      times),
+                  "(e2e) per token latency")
     print(f"RESPONSE 0:")
     print("-" * 30)
     print(responses[0][0]["generated_text"])
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 81566e7165c5..89a8d8288455 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -2,6 +2,7 @@
 Copyright 2021 The Microsoft DeepSpeed Team
 '''
 import torch
+import time
 import os
 
 from deepspeed import comm as dist
@@ -100,6 +101,8 @@ def __init__(self,
         self.cuda_graph_created = False
         self.checkpoint_engine = TorchCheckpointEngine()
         self._init_quantization_setting(quantization_setting)
+        self.model_profile_enabled = False
+        self._model_times = []
 
         # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
         self.remove_mask_prepare_for_bloom()
@@ -164,6 +167,12 @@ def __init__(self,
         if self.mp_world_size > 1:
             assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
 
+    def profile_model_time(self):
+        if not self.model_profile_enabled and not self.enable_cuda_graph:
+            self.module.register_forward_pre_hook(self._pre_forward_hook)
+            self.module.register_forward_hook(self._post_forward_hook)
+        self.model_profile_enabled = True
+
     def _get_model_config_generate(self, config):
         self.config = getattr(self.module, 'config', None) if config is None else config
         self.generate = getattr(self.module, 'generate', None)
@@ -173,6 +182,15 @@ def remove_mask_prepare_for_bloom(self):
             if hasattr(self.module.transformer, '_prepare_attn_mask'):
                 self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
 
+    def _pre_forward_hook(self, module, *inputs, **kwargs):
+        torch.cuda.synchronize()
+        self._start = time.time()
+
+    def _post_forward_hook(self, module, input, output):
+        torch.cuda.synchronize()
+        self._end = time.time()
+        self._model_times.append(self._end - self._start)
+
     def _create_model_parallel_group(self):
         # Call the init process
         if InferenceEngine.inference_mp_group is None:
@@ -364,7 +382,8 @@ def _apply_injection_policy(self,
                                   training_mp_size=training_mp_size,
                                   checkpoint_dict=checkpoint,
                                   save_mp_checkpoint_path=save_mp_checkpoint_path,
-                                  base_dir=base_dir)
+                                  base_dir=base_dir,
+                                  enable_cuda_graph=self.enable_cuda_graph)
 
     def _get_all_ckpt_names(self, checkpoints_path, tag):
         ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -501,6 +520,18 @@ def _graph_replay(self, *inputs, **kwargs):
         self._cuda_graphs.replay()
         return self.static_output
 
+    def model_times(self):
+        assert self.model_profile_enabled, "model profiling is not enabled"
+        model_times = self._model_times
+        if self.enable_cuda_graph and len(self._model_times) == 0:
+            raise ValueError(
+                "Model times are empty and cuda graph is enabled. If "
+                "this is a GPT-style model this combo is not supported. If this is a "
+                "BERT-style model this is a bug, please report it. "
" + f"Model type is: {type(self.module)}") + self._model_times = [] + return model_times + def forward(self, *inputs, **kwargs): """Execute forward propagation @@ -508,6 +539,11 @@ def forward(self, *inputs, **kwargs): *inputs: Variable length input list **kwargs: variable length keyword arguments """ + start = None + if self.model_profile_enabled and self.enable_cuda_graph: + torch.cuda.synchronize() + start = time.time() + if self.enable_cuda_graph: if self.cuda_graph_created: outputs = self._graph_replay(*inputs, **kwargs) @@ -517,4 +553,9 @@ def forward(self, *inputs, **kwargs): else: outputs = self.module(*inputs, **kwargs) + if self.model_profile_enabled and self.enable_cuda_graph: + torch.cuda.synchronize() + duration = time.time() - start + self._model_times.append(duration) + return outputs diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index f42a973d8b21..ea0e13726316 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -215,7 +215,8 @@ def replace_transformer_layer(orig_layer_impl, moe_type='standard', checkpoint_dict=None, save_mp_checkpoint_path=None, - base_dir=""): + base_dir="", + enable_cuda_graph=False): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -259,7 +260,9 @@ def replace_with_policy(child, inference=False, layer_id=0): policy = policy_cls(child, inference=inference) - + if not policy.cuda_graph_supported: + # policy says cuda graph is not supported raise an error if set + assert not enable_cuda_graph, "cuda graph is not supported with this model, please disable" if inference: hidden_size, num_attention_heads = policy.get_hidden_heads() assert num_attention_heads % mp_size == 0,\ @@ -687,7 +690,7 @@ def _replace(child, name, conv_linear_layer): weight_shape = child.weight.ds_shape else: weight_shape = child.weight.shape - if name in all_reduce_linears: + if isinstance(all_reduce_linears, dict) and name in all_reduce_linears: new_weight = torch.empty(( weight_shape[1] if conv_linear_layer else weight_shape[0], (weight_shape[0] if conv_linear_layer else weight_shape[1]) // diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 0f8045dbd273..cb7c4818961a 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -24,7 +24,7 @@ def __init__( mlp_act_func_type=ActivationFuncType.GELU, # applies layer norm before attention if `pre_attn_norm` is set to True pre_attn_norm=True): - + self.cuda_graph_supported = False self.inference = inference self.linear_layer = linear_layer self.scale_attention = scale_attention @@ -69,6 +69,7 @@ class HFBertLayerPolicy(DSPolicy): def __init__(self, client_module, inference=False): super().__init__(inference, pre_attn_norm=False) self.client_module = client_module + self.cuda_graph_supported = True if HFBertLayerPolicy._orig_layer_class is None: try: diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 04e9320accc8..5412fef264eb 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -132,6 +132,8 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph): msg = f"Not enough GPU memory to run {model} with dtype {dtype}" elif ("bloom" in model) and (dtype != torch.half): msg = f"Bloom models only support 
+    elif ("bert" not in model.lower()) and enable_cuda_graph:
+        msg = "Non bert/roberta models do not support CUDA Graph"
     return msg
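
Taken together, the engine changes above expose a small profiling API: `profile_model_time()` registers forward pre/post hooks (or a timing path around CUDA graph replay) and `model_times()` returns the collected per-forward durations. Below is a minimal usage sketch, not part of the patch itself; it assumes a single CUDA GPU and mirrors the default bert-bench.py configuration, and the model name and sample input are illustrative only.

```python
# Minimal sketch of the model-time profiling API introduced by this patch.
# Assumptions: one CUDA GPU, kernel injection on, CUDA graphs off;
# "bert-base-cased" and the sample input are illustrative only.
import deepspeed
import torch
from transformers import pipeline

deepspeed.init_distributed("nccl")

pipe = pipeline("fill-mask", model="bert-base-cased", framework="pt", device=0)
pipe.model.half()

# Wrap the HF model with the DeepSpeed inference engine.
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      mp_size=1,
                                      replace_with_kernel_inject=True,
                                      replace_method='auto',
                                      enable_cuda_graph=False)

# Register the forward pre/post hooks added in engine.py so that every
# forward pass records its model-only wall-clock time.
pipe.model.profile_model_time()

for _ in range(10):
    pipe("Hello I'm a [MASK] model")

# model_times() returns the list of per-forward durations (in seconds)
# collected since the last call, then clears the internal buffer.
print(pipe.model.model_times())
```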