88 changes: 88 additions & 0 deletions benchmarks/inference/bert-bench.py
@@ -0,0 +1,88 @@
import torch
import time
import deepspeed
import argparse
from transformers import pipeline

parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=str, help="hf model name")
parser.add_argument("--deepspeed", action="store_true", help="use deepspeed inference")
parser.add_argument("--dtype", type=str, default="fp16", help="fp16 or fp32")
parser.add_argument("--max-tokens", type=int, default=50, help="max new tokens")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--trials", type=int, default=30, help="number of trials")
parser.add_argument("--kernel-inject", action="store_true", help="inject kernels on")
parser.add_argument("--graphs", action="store_true", help="CUDA Graphs on")
args = parser.parse_args()


def print_latency(latency_set, title, warmup=3):
    # trim warmup queries
    latency_set = latency_set[warmup:]
    count = len(latency_set)
    if count > 0:
        latency_set.sort()
        n50 = (count - 1) * 0.5 + 1
        n90 = (count - 1) * 0.9 + 1
        n95 = (count - 1) * 0.95 + 1
        n99 = (count - 1) * 0.99 + 1
        n999 = (count - 1) * 0.999 + 1

        avg = sum(latency_set) / count
        p50 = latency_set[int(n50) - 1]
        p90 = latency_set[int(n90) - 1]
        p95 = latency_set[int(n95) - 1]
        p99 = latency_set[int(n99) - 1]
        p999 = latency_set[int(n999) - 1]

        print(f"====== latency stats {title} ======")
        print("\tAvg Latency: {0:8.2f} ms".format(avg * 1000))
        print("\tP50 Latency: {0:8.2f} ms".format(p50 * 1000))
        print("\tP90 Latency: {0:8.2f} ms".format(p90 * 1000))
        print("\tP95 Latency: {0:8.2f} ms".format(p95 * 1000))
        print("\tP99 Latency: {0:8.2f} ms".format(p99 * 1000))
        print("\tP999 Latency: {0:8.2f} ms".format(p999 * 1000))


deepspeed.init_distributed("nccl")

print(args.model, args.max_tokens, args.dtype)

if args.dtype.lower() == "fp16":
    dtype = torch.float16
else:
    dtype = torch.float32

pipe = pipeline("fill-mask", model=args.model, framework="pt", device=args.local_rank)

if dtype == torch.half:
    pipe.model.half()

br = pipe("Hello I'm a [MASK] model")
if args.deepspeed:
    pipe.model = deepspeed.init_inference(pipe.model,
                                          dtype=dtype,
                                          mp_size=1,
                                          replace_with_kernel_inject=args.kernel_inject,
                                          replace_method='auto',
                                          enable_cuda_graph=args.graphs)
    pipe.model.profile_model_time()

responses = []
times = []
mtimes = []
for i in range(args.trials):
    torch.cuda.synchronize()
    start = time.time()
    r = pipe("Hello I'm a [MASK] model")
    torch.cuda.synchronize()
    end = time.time()
    responses.append(r)
    times.append((end - start))
    mtimes += pipe.model.model_times()
    #print(f"{pipe.model.model_times()=}")

print_latency(times, "e2e latency")
print_latency(mtimes, "model latency")

print(responses[0:3])
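A note on the percentile readout in print_latency above: it uses nearest-rank selection on the sorted latency list. A minimal standalone sketch of the same arithmetic (the function name and sample values here are illustrative, not part of the benchmark):

def nearest_rank(sorted_vals, frac):
    # Mirrors print_latency: rank = (n - 1) * frac + 1, then index the
    # sorted list at rank - 1 (0-based), truncating the fractional part.
    rank = (len(sorted_vals) - 1) * frac + 1
    return sorted_vals[int(rank) - 1]

lat = sorted([0.011, 0.010, 0.012, 0.013, 0.009])
print(nearest_rank(lat, 0.5))   # 0.011 -> rank 3 of 5 sorted samples
print(nearest_rank(lat, 0.9))   # 0.012 -> rank 4.6 truncates to index 3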
12 changes: 10 additions & 2 deletions benchmarks/inference/gpt-bench.py
@@ -34,6 +34,7 @@

def print_latency(latency_set, title, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)
    if count > 0:
@@ -94,20 +95,27 @@ def print_latency(latency_set, title, warmup=3):
        replace_method="auto",
        enable_cuda_graph=args.graphs,
    )
    pipe.model.profile_model_time()

responses = []
times = []
mtimes = []
for i in range(args.trials):
    torch.cuda.synchronize()
    start = time.time()
    r = pipe("DeepSpeed is", do_sample=False, max_new_tokens=args.max_tokens)
    torch.cuda.synchronize()
    end = time.time()
    responses.append(r)
    times.append((end - start) / (args.max_tokens - 3))
    times.append(end - start)  # / (args.max_tokens - 3))
    mtimes.append(sum(pipe.model.model_times()))

if args.local_rank == 0:
    print_latency(times, "token latency")
    print_latency(times, "(e2e) latency")
    print_latency(mtimes, "(model-only) latency")
    print_latency(map(lambda t: t / (args.max_tokens - 3),
                      times),
                  "(e2e) per token latency")
    print(f"RESPONSE 0:")
    print("-" * 30)
    print(responses[0][0]["generated_text"])
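The list(latency_set) copy added to print_latency above is what allows the per-token view to be passed in lazily as a map object, which cannot be sliced or sorted directly. A small self-contained sketch of that pattern (names and numbers are illustrative):

max_new_tokens = 50
times = [0.51, 0.50, 0.52, 0.49]                            # e2e seconds per trial
per_token = map(lambda t: t / (max_new_tokens - 3), times)  # lazy iterator

def summarize(latency_set, warmup=1):
    latency_set = list(latency_set)   # materialize so slicing/sorting works
    latency_set = latency_set[warmup:]
    latency_set.sort()
    return sum(latency_set) / len(latency_set)

print(summarize(per_token))   # average per-token latency after trimming warmup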
43 changes: 42 additions & 1 deletion deepspeed/inference/engine.py
@@ -2,6 +2,7 @@
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
import time
import os

from deepspeed import comm as dist
@@ -100,6 +101,8 @@ def __init__(self,
        self.cuda_graph_created = False
        self.checkpoint_engine = TorchCheckpointEngine()
        self._init_quantization_setting(quantization_setting)
        self.model_profile_enabled = False
        self._model_times = []

        # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
        self.remove_mask_prepare_for_bloom()
@@ -164,6 +167,12 @@ def __init__(self,
        if self.mp_world_size > 1:
            assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism"

    def profile_model_time(self):
        if not self.model_profile_enabled and not self.enable_cuda_graph:
            self.module.register_forward_pre_hook(self._pre_forward_hook)
            self.module.register_forward_hook(self._post_forward_hook)
        self.model_profile_enabled = True

    def _get_model_config_generate(self, config):
        self.config = getattr(self.module, 'config', None) if config is None else config
        self.generate = getattr(self.module, 'generate', None)
@@ -173,6 +182,15 @@ def remove_mask_prepare_for_bloom(self):
        if hasattr(self.module.transformer, '_prepare_attn_mask'):
            self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask

    def _pre_forward_hook(self, module, *inputs, **kwargs):
        torch.cuda.synchronize()
        self._start = time.time()

    def _post_forward_hook(self, module, input, output):
        torch.cuda.synchronize()
        self._end = time.time()
        self._model_times.append(self._end - self._start)

    def _create_model_parallel_group(self):
        # Call the init process
        if InferenceEngine.inference_mp_group is None:
@@ -364,7 +382,8 @@ def _apply_injection_policy(self,
                                  training_mp_size=training_mp_size,
                                  checkpoint_dict=checkpoint,
                                  save_mp_checkpoint_path=save_mp_checkpoint_path,
                                  base_dir=base_dir)
                                  base_dir=base_dir,
                                  enable_cuda_graph=self.enable_cuda_graph)

    def _get_all_ckpt_names(self, checkpoints_path, tag):
        ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -501,13 +520,30 @@ def _graph_replay(self, *inputs, **kwargs):
        self._cuda_graphs.replay()
        return self.static_output

    def model_times(self):
        assert self.model_profile_enabled, "model profiling is not enabled"
        model_times = self._model_times
        if self.enable_cuda_graph and len(self._model_times) == 0:
            raise ValueError(
                "Model times are empty and cuda graph is enabled. If "
                "this is a GPT-style model this combo is not supported. If this is a "
                "BERT-style model this is a bug, please report it. "
                f"Model type is: {type(self.module)}")
        self._model_times = []
        return model_times

    def forward(self, *inputs, **kwargs):
        """Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        start = None
        if self.model_profile_enabled and self.enable_cuda_graph:
            torch.cuda.synchronize()
            start = time.time()

        if self.enable_cuda_graph:
            if self.cuda_graph_created:
                outputs = self._graph_replay(*inputs, **kwargs)
@@ -517,4 +553,9 @@ def forward(self, *inputs, **kwargs):
        else:
            outputs = self.module(*inputs, **kwargs)

        if self.model_profile_enabled and self.enable_cuda_graph:
            torch.cuda.synchronize()
            duration = time.time() - start
            self._model_times.append(duration)

        return outputs
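Taken together, the engine changes add an opt-in profiling path: profile_model_time() registers the pre/post forward hooks (or, under CUDA graphs, timing moves into forward itself), and model_times() drains the collected per-forward latencies. A minimal usage sketch mirroring the benchmarks above (the model name is illustrative):

import torch
import deepspeed
from transformers import pipeline

pipe = pipeline("fill-mask", model="bert-base-uncased", framework="pt", device=0)
pipe.model.half()
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      mp_size=1,
                                      replace_with_kernel_inject=True)
pipe.model.profile_model_time()          # enable the timing hooks

_ = pipe("Hello I'm a [MASK] model")     # each forward pass is timed
print(pipe.model.model_times())          # per-forward seconds; buffer is cleared on read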
9 changes: 6 additions & 3 deletions deepspeed/module_inject/replace_module.py
@@ -215,7 +215,8 @@ def replace_transformer_layer(orig_layer_impl,
                              moe_type='standard',
                              checkpoint_dict=None,
                              save_mp_checkpoint_path=None,
                              base_dir=""):
                              base_dir="",
                              enable_cuda_graph=False):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@@ -259,7 +260,9 @@ def replace_with_policy(child,
                            inference=False,
                            layer_id=0):
        policy = policy_cls(child, inference=inference)

        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported, raise an error if set
            assert not enable_cuda_graph, "cuda graph is not supported with this model, please disable"
        if inference:
            hidden_size, num_attention_heads = policy.get_hidden_heads()
            assert num_attention_heads % mp_size == 0,\
@@ -687,7 +690,7 @@ def _replace(child, name, conv_linear_layer):
                weight_shape = child.weight.ds_shape
            else:
                weight_shape = child.weight.shape
            if name in all_reduce_linears:
            if isinstance(all_reduce_linears, dict) and name in all_reduce_linears:
                new_weight = torch.empty((
                    weight_shape[1] if conv_linear_layer else weight_shape[0],
                    (weight_shape[0] if conv_linear_layer else weight_shape[1]) //
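The enable_cuda_graph plumbing above pairs with a per-policy capability flag (see replace_policy.py below): the base policy defaults to unsupported and only the BERT policy opts in, so requesting graphs for any other model trips the assert. A stripped-down sketch of that gate, with hypothetical class names:

class BasePolicy:
    def __init__(self):
        self.cuda_graph_supported = False   # conservative default

class BertLayerPolicy(BasePolicy):
    def __init__(self):
        super().__init__()
        self.cuda_graph_supported = True    # BERT-style layers can be captured

def check_policy(policy, enable_cuda_graph):
    if not policy.cuda_graph_supported:
        # mirrors replace_with_policy: refuse graphs the policy cannot honor
        assert not enable_cuda_graph, "cuda graph is not supported with this model, please disable"

check_policy(BertLayerPolicy(), enable_cuda_graph=True)   # passes
check_policy(BasePolicy(), enable_cuda_graph=False)       # passes; True would raise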
3 changes: 2 additions & 1 deletion deepspeed/module_inject/replace_policy.py
@@ -24,7 +24,7 @@ def __init__(
            mlp_act_func_type=ActivationFuncType.GELU,
            # applies layer norm before attention if `pre_attn_norm` is set to True
            pre_attn_norm=True):

        self.cuda_graph_supported = False
        self.inference = inference
        self.linear_layer = linear_layer
        self.scale_attention = scale_attention
@@ -69,6 +69,7 @@ class HFBertLayerPolicy(DSPolicy):
    def __init__(self, client_module, inference=False):
        super().__init__(inference, pre_attn_norm=False)
        self.client_module = client_module
        self.cuda_graph_supported = True

        if HFBertLayerPolicy._orig_layer_class is None:
            try: