
Commit 94c7233

Refactor autoTP inference for HE (#4040)
* Refactor autoTP inference for HE
* Formatting
* Move redundant functions to autotp
* Remove self from loading class
* formatting
* Some gpt2 autotp path fixes
* precommit
1 parent e31b404 commit 94c7233

File tree

2 files changed (+331, -304)


deepspeed/module_inject/auto_tp.py (+315)
@@ -8,10 +8,175 @@
from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw

class ReplaceWithTensorSlicing:

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        assert dim1 > dim2, \
            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\
            for merging your checkpoints before replacing the transformer layer with\
            inference-kernels'

    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):

            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
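To make the slicing mechanics concrete, here is a minimal sketch, not part of this commit: it runs in a single process so gpu_index falls back to 0, assumes ReplaceWithTensorSlicing is importable from deepspeed.module_inject.auto_tp after this refactor, and the names full_weight and local_dst are invented.

import torch
from deepspeed.module_inject.auto_tp import ReplaceWithTensorSlicing

# No mp_group: gpu_index defaults to 0, so copy() returns shard 0.
mp_replace = ReplaceWithTensorSlicing(mp_group=None, mp_size=2)
full_weight = torch.randn(8, 4)  # full [out_features, in_features] weight
local_dst = torch.empty(4, 4)    # this rank's slice: out dim divided by mp_size

# copy() detects the dim-0 mismatch (8 vs. 4) and copies rows 0:4 into dst.
shard = mp_replace.copy(local_dst, full_weight)
print(shard.shape)  # torch.Size([4, 4])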
class Loading():

    def load_buffer(module, state_dict, prefix):
        for name in module._buffers.keys():
            if module._buffers[name].data.is_meta:
                module._buffers[name] = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module._buffers[name].data, device="cpu"),
                    requires_grad=module._buffers[name].data.requires_grad)
            if prefix + name in state_dict.keys():
                module._buffers[name].data.copy_(state_dict[prefix + name])

    def load(module, state_dict, prefix, mp_group=None):
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            if module.weight.data.is_meta:
                # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
                                                             requires_grad=module.weight.data.requires_grad)
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        else:
            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
                if module.norm.weight.data.is_meta:
                    # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                    module.norm.weight = torch.nn.parameter.Parameter(
                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
                        requires_grad=module.norm.weight.data.requires_grad)
                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict.keys():
            if hasattr(module, 'bias'):
                if module.bias.data.is_meta:
                    # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            else:
                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                    if module.norm.bias.data.is_meta:
                        # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                        module.norm.bias = torch.nn.parameter.Parameter(
                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
                            requires_grad=module.norm.bias.data.requires_grad)
                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
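A minimal sketch of the meta-tensor path in Loading.load (illustrative only; the toy Linear and the "fc." prefix are invented, and real callers pass model state dicts):

import torch
from deepspeed.module_inject.auto_tp import Loading

# A Linear created on the meta device has shape metadata but no storage.
lin = torch.nn.Linear(4, 4, device="meta")
sd = {"fc.weight": torch.randn(4, 4), "fc.bias": torch.zeros(4)}

# load() swaps the meta parameters for real CPU tensors, then copies from sd.
Loading.load(lin, sd, prefix="fc.", mp_group=None)
print(lin.weight.device, lin.weight.shape)  # cpu torch.Size([4, 4])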
class AutoTP():

    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict

        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        self.linear_policies = None
        self.conv_linear_layer = False
    def in_module_list(module, module_list):
        for item in module_list:
            if type(item).__name__ == type(module).__name__:
@@ -122,3 +287,153 @@ def tp_parser(model):
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):
        self.mp_size = mp_size
        self.mp_group = mp_group
    def _replace(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False):
            return
        weight_shape = child.weight.shape
        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
            # The MPT block's qkv weight is laid out as [3, num_head, head_dim, hidden_size]
            # instead of [num_head, 3, head_dim, hidden_size] as in other models.
            new_weight = torch.empty((
                weight_shape[0] // self.mp_size,
                weight_shape[1],
            ),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            reversed_dim = True
            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
            # todo: can we remove the new tensor allocation if we use strided copy?
            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
            setattr(child, "replaced", True)
            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
        if name in self.all_reduce_linears:
            # if conv_linear_layer: [weight_shape[1], weight_shape[0] // mp_size]
            # else: [weight_shape[0], weight_shape[1] // mp_size]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = child.weight.data.split(
                (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
            data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

            setattr(child, "replaced", True)
            return LinearAllreduce(data, child.bias if child.bias is None else \
                        torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
        else:
            # if conv_linear_layer: [weight_shape[1], weight_shape[0] // mp_size]
            # else: [weight_shape[0] // mp_size, weight_shape[1]]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

            if require_tp_fused_qkvw(name, self.mp_size):
                # for detecting the fused type
                module_str = str(self.module).strip()
                # this copy is a regular copy; dst and src have the same shape
                data = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index)

                bias_data = None if child.bias is None else prepare_tp_fused_qkvw(
                    module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                        get_accelerator().current_device_name())
            else:
                data = child.weight.data.split((weight_shape[0]) // self.mp_size,
                                               dim=1 if self.conv_linear_layer else 0)
                data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

                if child.bias is not None:
                    bias_data = child.bias.data.split(
                        (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
                    bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                else:
                    bias_data = None

            setattr(child, "replaced", True)
            return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)
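A rough, pure-torch illustration of the two non-fused branches above (shapes and names invented for this sketch): all-reduce linears keep every output row and split the input dimension, relying on an all-reduce to sum partial outputs, while all other linears split the output dimension so each rank computes a slice of the output.

import torch

mp_size, rank = 2, 0
w = torch.randn(6, 4)  # nn.Linear weight layout: [out_features, in_features]

# all_reduce_linears (e.g. attention/MLP output projections): split dim=1.
w_allreduce = w.split(w.shape[1] // mp_size, dim=1)[rank]  # shape [6, 2]

# everything else: split dim=0.
w_parallel = w.split(w.shape[0] // mp_size, dim=0)[rank]   # shape [3, 4]
print(w_allreduce.shape, w_parallel.shape)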
    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False):
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        else:
            data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

        new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding
    def update_mp_params(self, child):
        if getattr(child, "replaced", False):
            return
        for param in [
                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads"
        ]:
            if hasattr(child, param):
                param_val = getattr(child, param)
                assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
                setattr(child, param, param_val // self.mp_size)
        setattr(child, "replaced", True)
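The arithmetic is simple but worth seeing once (reimplemented standalone here, not a call into AutoTP; SimpleNamespace stands in for a real attention module): with mp_size=2, head counts and hidden sizes are halved so module bookkeeping matches the sharded weights.

from types import SimpleNamespace

mp_size = 2
child = SimpleNamespace(num_heads=16, hidden_size=1024)  # stand-in module
for param in ["num_heads", "hidden_size"]:
    setattr(child, param, getattr(child, param) // mp_size)
print(child.num_heads, child.hidden_size)  # 8 512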
    def update_linear_policies(self):
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
        else:
            import transformers
            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
                try:
                    self.conv_linear_layer = True
                    self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace}
                except ImportError:
                    self.linear_policies = {nn.Linear: self._replace}
            else:
                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}
    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        for name, child in r_module.named_children():
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name
            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
                # Note: isinstance accounts for class inheritance; child.__class__ does not
                key = None
                for lp in self.linear_policies:
                    if isinstance(child, lp):
                        key = lp
                        break
                assert key is not None
                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
                                                                  self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module
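Taken together, the new object is presumably driven from replace_module (the other file in this commit, not shown here) roughly as follows; this is a hedged sketch of intended usage, not code from this diff, and the surrounding variables are assumptions.

autotp = AutoTP(module, all_reduce_linears, prefix, state_dict,
                linear_layer_setting, orig_layer_impl)
autotp.set_tensor_parallel_config(mp_size, mp_group)  # 1. set TP size/group
autotp.update_linear_policies()                       # 2. choose per-layer policies
module = autotp._replace_module(module)               # 3. recurse and replace layers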
