 
 from torch import nn
 from .replace_policy import replace_policies
+from typing import Optional
+import torch
+from deepspeed import comm as dist
+from .layers import LinearAllreduce, LinearLayer
+from deepspeed.accelerator import get_accelerator
+from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
+
+
+class ReplaceWithTensorSlicing:
+
+    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
+        if mp_group is not None:
+            self.gpu_index = dist.get_rank(group=mp_group)
+        else:
+            self.gpu_index = 0
+        self.out_dim = out_dim
+        self.in_dim = in_dim
+        self.mp_size = mp_size
+
+    def merge_assert(self, dim1, dim2):
+        assert dim1 > dim2, \
+            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\
+            for merging your checkpoints before replacing the transformer layer with\
+            inference-kernels'
+
+    def strided_copy(self,
+                     dst: Optional[torch.Tensor],
+                     src: Optional[torch.Tensor],
+                     num_splits: int,
+                     int8: bool = False,
+                     allocate_tensor: bool = False):
+        if src is None:
+            return src
+        src_shape = src.shape
+        dst_shape = dst.shape
+
+        outer_dim = 0 if int8 else -1
+
+        if allocate_tensor:
+            dst = torch.empty_like(dst)
+
+        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
+        if (len(src_shape) == 2 and len(dst_shape) == 2):
+            if src_shape[outer_dim] == dst_shape[self.out_dim]:
+                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
+                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
+                if hasattr(src, 'scale'):
+                    dst.scale = src.scale
+                return dst
+            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
+            qkv_size = dst_shape[self.out_dim] // num_splits
+            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
+            weight_split = [
+                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
+            ]
+            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
+                weight_split[self.gpu_index].shape)
+        else:
+            if src_shape[0] == dst_shape[0]:
+                return torch.nn.parameter.Parameter(src)
+            qkv_size = dst_shape[0] // num_splits
+            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
+            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
+            dst.data.copy_(bias_split[self.gpu_index].contiguous())
+
+        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
+        if hasattr(src, 'scale'):
+            dst.scale = src.scale
+        return dst
+
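For readers new to the fused-QKV layout, here is a minimal, self-contained sketch of the interleaved split that strided_copy performs (toy sizes, plain torch, split along dim 0 for simplicity; the method itself picks outer_dim based on the int8 flag):

import torch

# Assumed toy sizes: a fused QKV weight [3 * hidden, hidden] split across 2 ranks.
hidden, mp_size, num_splits = 8, 2, 3
fused = torch.randn(3 * hidden, hidden)

qkv = torch.split(fused, fused.shape[0] // num_splits, dim=0)  # Q, K, V
shard = qkv[0].shape[0] // mp_size                             # per-rank rows of each of Q/K/V
per_rank = [torch.split(t, shard, dim=0) for t in qkv]
# rank i keeps its Q, K and V shards re-concatenated, exactly like weight_split above
rank_weights = [torch.cat([t[i] for t in per_rank], dim=0) for i in range(mp_size)]

assert rank_weights[0].shape == (3 * hidden // mp_size, hidden)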
+    def copy(self, dst, src, int8=False, allocate_tensor=False):
+        if src is None:
+            return src
+        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
+        if allocate_tensor:
+            dst = torch.empty_like(dst)
+        outer_dim = 0 if int8 else 1
+        inner_dim = 1 if int8 else 0
+        src_shape = src.shape
+        dst_shape = dst.shape
+        if (len(src_shape) == 2 and len(dst_shape) == 2):
+            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
+                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
+            else:
+                if src_shape[inner_dim] != dst_shape[self.in_dim]:
+                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
+                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
+                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
+                else:
+                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
+                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
+                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
+        else:
+            if src_shape[0] == dst_shape[0]:
+                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
+            else:
+                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])
+        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
+        if hasattr(src, 'scale'):
+            dst.scale = src.scale
+        return dst
+
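A hedged usage sketch of copy (single process, no mp_group, so gpu_index defaults to 0; the shapes are illustrative assumptions): the destination's shape decides whether rows or columns of the source are sliced onto this rank.

import torch

mp_size = 2
src = torch.randn(16, 8)                         # full [out, in] weight
dst = torch.empty(16, 8 // mp_size)              # this rank's shard
mp_replace = ReplaceWithTensorSlicing(mp_size=mp_size)
shard = mp_replace.copy(dst, src)                # rank 0 receives src[:, 0:4]
assert shard.shape == (16, 4)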
+
+class Loading():
+
+    @staticmethod
+    def load_buffer(module, state_dict, prefix):
+        for name in module._buffers.keys():
+            if module._buffers[name].data.is_meta:
+                module._buffers[name] = torch.nn.parameter.Parameter(
+                    data=torch.empty_like(module._buffers[name].data, device="cpu"),
+                    requires_grad=module._buffers[name].data.requires_grad)
+            if prefix + name in state_dict.keys():
+                module._buffers[name].data.copy_(state_dict[prefix + name])
+
+    @staticmethod
+    def load(module, state_dict, prefix, mp_group=None):
+        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
+        if hasattr(module, 'weight'):
+            if module.weight.data.is_meta:
+                # a meta tensor cannot be cast or copied to, so replace it with a normal CPU tensor first
+                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
+                                                             requires_grad=module.weight.data.requires_grad)
+            if 'query_key_value' in prefix:
+                module.weight = mp_replace.strided_copy(module.weight.data,
+                                                        state_dict[prefix + 'weight'],
+                                                        num_splits=3)
+            else:
+                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
+        else:
+            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
+                if module.norm.weight.data.is_meta:
+                    # a meta tensor cannot be cast or copied to, so replace it with a normal CPU tensor first
+                    module.norm.weight = torch.nn.parameter.Parameter(
+                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
+                        requires_grad=module.norm.weight.data.requires_grad)
+                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
+
+        if prefix + 'bias' in state_dict.keys():
+            if hasattr(module, 'bias'):
+                if module.bias.data.is_meta:
+                    # a meta tensor cannot be cast or copied to, so replace it with a normal CPU tensor first
+                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
+                                                               requires_grad=module.bias.data.requires_grad)
+                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
+            else:
+                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
+                    if module.norm.bias.data.is_meta:
+                        # a meta tensor cannot be cast or copied to, so replace it with a normal CPU tensor first
+                        module.norm.bias = torch.nn.parameter.Parameter(
+                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
+                            requires_grad=module.norm.bias.data.requires_grad)
+                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
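The meta-tensor guards above exist because, as the inline comments note, copy_ into a tensor on the meta device does not actually transfer data; Loading therefore materializes an empty CPU tensor of the same shape first. A minimal illustration of that pattern:

import torch

meta_w = torch.empty(4, 4, device="meta")        # e.g. a model instantiated with empty weights
real_w = torch.empty_like(meta_w, device="cpu")  # materialize real storage on CPU
real_w.copy_(torch.ones(4, 4))                   # copy_ now actually moves data
assert not real_w.is_meta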
 
 
 class AutoTP():
 
+    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+        self.module = module
+        self.all_reduce_linears = all_reduce_linears
+        self.prefix = prefix
+        self.state_dict = state_dict
+
+        self.mp_size = None
+        self.mp_group = None
+        self.linear_layer_setting = linear_layer_setting
+        self.orig_layer_impl = orig_layer_impl
+        self.linear_policies = None
+        self.conv_linear_layer = False
+
     def in_module_list(module, module_list):
         for item in module_list:
             if type(item).__name__ == type(module).__name__:
@@ -122,3 +287,153 @@ def tp_parser(model):
         assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
             if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
         return policy_list
+
+    def set_tensor_parallel_config(self, mp_size, mp_group):
+        self.mp_size = mp_size
+        self.mp_group = mp_group
+
+    def _replace(self, child, name, conv_linear_layer):
+        if getattr(child, "replaced", False):
+            return
+        weight_shape = child.weight.shape
+        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
+            # The MPT block allocates its fused qkv weight as [3, num_head, head_dim, hidden_size]
+            # instead of [num_head, 3, head_dim, hidden_size], so it needs a strided copy
+            new_weight = torch.empty((
+                weight_shape[0] // self.mp_size,
+                weight_shape[1],
+            ),
+                                     device=child.weight.device,
+                                     dtype=child.weight.dtype)
+            reversed_dim = True
+            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
+            # TODO: can we remove the new tensor allocation if we use strided copy?
+            # passing int8=reversed_dim repurposes the int8 flag so the split runs along dim 0
+            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
+            setattr(child, "replaced", True)
+            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
+        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+        if name in self.all_reduce_linears:
+            # sharded weight shape: [weight_shape[1], weight_shape[0] // mp_size] if conv_linear_layer,
+            # [weight_shape[0], weight_shape[1] // mp_size] otherwise
+            if self.conv_linear_layer:
+                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
+            data = child.weight.data.split(
+                (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
+            data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
+
+            setattr(child, "replaced", True)
+            bias = child.bias if child.bias is None else \
+                torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name()))
+            return LinearAllreduce(data, bias, self.mp_group)
+        else:
+            # sharded weight shape: [weight_shape[1], weight_shape[0] // mp_size] if conv_linear_layer,
+            # [weight_shape[0] // mp_size, weight_shape[1]] otherwise
+            if self.conv_linear_layer:
+                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
+
+            if require_tp_fused_qkvw(name, self.mp_size):
+                # detect the fused-qkv type from the module string; the copy itself is a
+                # regular copy since dst and src have the same shape
+                module_str = str(self.module).strip()
+                data = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index)
+                bias_data = None if child.bias is None else prepare_tp_fused_qkvw(
+                    module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
+                        get_accelerator().current_device_name())
+            else:
+                data = child.weight.data.split(weight_shape[0] // self.mp_size,
+                                               dim=1 if self.conv_linear_layer else 0)
+                data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
+
+                if child.bias is not None:
+                    bias_data = child.bias.data.split(
+                        (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
+                    bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
+                else:
+                    bias_data = None
+
+            setattr(child, "replaced", True)
+            return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)
+
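The two return paths correspond to the standard Megatron-style schemes: LinearAllreduce is row parallel (the input dimension is sharded and partial outputs are all-reduced), LinearLayer is column parallel (the output dimension is sharded). A toy single-process check of the arithmetic, with the all-reduce modeled as a plain sum:

import torch

mp, out_f, in_f = 2, 6, 4
w, x = torch.randn(out_f, in_f), torch.randn(1, in_f)

# column parallel (LinearLayer): shard the output dim, concatenate partial outputs
col = torch.cat([x @ w_i.t() for w_i in w.chunk(mp, dim=0)], dim=-1)

# row parallel (LinearAllreduce): shard the input dim, sum (all-reduce) partial outputs
row = sum(x_i @ w_i.t() for x_i, w_i in zip(x.chunk(mp, dim=-1), w.chunk(mp, dim=1)))

assert torch.allclose(col, x @ w.t(), atol=1e-5)
assert torch.allclose(row, x @ w.t(), atol=1e-5)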
+    def _slice_embedding(self, child, name, conv_linear_layer):
+        if getattr(child, "replaced", False):
+            return
+        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
+
+        if hasattr(child.weight, 'ds_tensor'):
+            data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
+        else:
+            data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
+        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
+
+        new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
+        new_embedding.weight.data.copy_(data)
+        setattr(child, "replaced", True)
+        return new_embedding
+
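For reference, a toy version of the slice above: each rank keeps a contiguous chunk of the embedding dimension, so the per-rank table is [vocab, dim // mp_size] (sizes here are assumptions for illustration):

import torch
from torch import nn

mp, rank, vocab, dim = 2, 0, 10, 8
full = nn.Embedding(vocab, dim)

shard = nn.Embedding(vocab, dim // mp)
shard.weight.data.copy_(full.weight.data.split(dim // mp, dim=1)[rank])
assert shard.weight.shape == (vocab, dim // mp)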
+    def update_mp_params(self, child):
+        if getattr(child, "replaced", False):
+            return
+        for param in [
+                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
+                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads"
+        ]:
+            if hasattr(child, param):
+                param_val = getattr(child, param)
+                assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
+                setattr(child, param, param_val // self.mp_size)
+        setattr(child, "replaced", True)
+
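Its effect on a module's bookkeeping attributes, sketched with a stand-in object (autotp is assumed to be an AutoTP instance with mp_size=2; FakeAttention is hypothetical):

class FakeAttention:
    num_attention_heads = 16
    all_head_size = 1024

attn = FakeAttention()
autotp.update_mp_params(attn)  # divides every matched attribute by mp_size
assert attn.num_attention_heads == 8 and attn.all_head_size == 512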
+    def update_linear_policies(self):
+        self.conv_linear_layer = False
+        if self.linear_layer_setting is not None:
+            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
+            if len(self.linear_layer_setting) == 2:
+                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
+        else:
+            import transformers
+            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
+                try:
+                    self.conv_linear_layer = True
+                    self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace}
+                except ImportError:
+                    self.linear_policies = {nn.Linear: self._replace}
+            else:
+                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}
+
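Putting the pieces together, a hedged sketch of how these methods might be driven (the block, prefix, state dict, and all-reduce list are illustrative assumptions, not a fixed public API):

autotp = AutoTP(module=block,                       # hypothetical transformer block
                all_reduce_linears=['mlp.dense_4h_to_h', 'self_attention.dense'],
                prefix='transformer.h.0',
                state_dict=sd,                      # hypothetical checkpoint state dict
                linear_layer_setting=None,
                orig_layer_impl=None)
autotp.set_tensor_parallel_config(mp_size=2, mp_group=None)
autotp.update_linear_policies()
autotp._replace_module(block)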
+    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
+        for name, child in r_module.named_children():
+            if prev_class_name == "":
+                class_name = prev_name
+            elif prev_name == "":
+                class_name = prev_class_name
+            else:
+                class_name = prev_class_name + '.' + prev_name
+            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
+            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
+                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
+                if any(checking_key in item for item in self.state_dict):
+                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
+                else:
+                    continue
+            if len(child._buffers) != 0 and self.state_dict is not None:
+                Loading.load_buffer(child, self.state_dict, checking_key)
+            if child.__class__ in self.linear_policies:
+                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
+                                                                              self.conv_linear_layer))
+            elif any(isinstance(child, lp) for lp in self.linear_policies):
+                # Added for falcon model support
+                # Note: isinstance will account for class inheritance, child.__class__ does not
+                key = None
+                for lp in self.linear_policies:
+                    if isinstance(child, lp):
+                        key = lp
+                        break
+                assert key is not None
+                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
+                                                                  self.conv_linear_layer))
+            else:
+                self.update_mp_params(child)
+                self._replace_module(child, name, class_name)
+        return r_module