From cc0a7db2c60e70e5f43837ca731df07e67de629b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 26 Jul 2022 03:39:38 +0500 Subject: [PATCH 01/15] Fix random token-generation issue + MP-checkpoint loading/saving --- deepspeed/__init__.py | 6 +- deepspeed/inference/engine.py | 27 +++++-- deepspeed/module_inject/load_checkpoint.py | 80 ++++++++++++------- deepspeed/module_inject/replace_module.py | 44 ++++++++-- .../inference/transformer_inference.py | 1 + deepspeed/runtime/state_dict_factory.py | 5 +- 6 files changed, 113 insertions(+), 50 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 7a18f98a49e8..296701e2e113 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -238,7 +238,8 @@ def init_inference(model, moe_experts=1, moe_type='standard', args=None, - enable_cuda_graph=False): + enable_cuda_graph=False, + save_mp_checkpoint=False): """Initialize the DeepSpeed InferenceEngine. Arguments: @@ -304,6 +305,7 @@ def init_inference(model, moe_experts, moe_type, args, - enable_cuda_graph) + enable_cuda_graph, + save_mp_checkpoint) return engine diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index ea29282da744..1979d66329a2 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -50,7 +50,8 @@ def __init__(self, moe_experts=1, moe_type='standard', config=None, - enable_cuda_graph=False): + enable_cuda_graph=False, + save_mp_checkpoint=False): """ Args: model: torch.nn.Module @@ -129,7 +130,8 @@ def __init__(self, moe_experts, moe_type, training_mp_size, - self.checkpoint if replace_with_kernel_inject else None) + self.checkpoint if replace_with_kernel_inject else None, + save_mp_checkpoint=save_mp_checkpoint) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, @@ -138,12 +140,18 @@ def __init__(self, moe_experts=moe_experts, moe_type=moe_type, training_mp_size=training_mp_size, - checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None) + checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None, + save_mp_checkpoint=save_mp_checkpoint) device = torch.cuda.current_device() - logger.info(f"Place model to device: {device}") + # logger.info(f"Place model to device: {device}") self.module.to(device) + if self.mp_world_size > 1: + _rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device()) + dist.broadcast(_rng_state, 0) + torch.cuda.set_rng_state(_rng_state.cpu()) + if self.mp_world_size > 1: self.model_orig_fwd = self.module.forward self.module.forward = self.forward @@ -312,8 +320,9 @@ def _apply_injection_policy(self, moe_experts=1, moe_type='standard', training_mp_size=1, - checkpoint_dir=None): - checkpoint = SDLoaderFactory.get_sd_loader_json( + checkpoint_dir=None, + save_mp_checkpoint=False): + checkpoint, ckpt_type = SDLoaderFactory.get_sd_loader_json( checkpoint_dir) if checkpoint_dir is not None else None replace_transformer_layer(client_module, self.module, @@ -337,7 +346,9 @@ def _apply_injection_policy(self, moe_experts=moe_experts, moe_type=moe_type, training_mp_size=training_mp_size, - checkpoint=checkpoint) + checkpoint=checkpoint, + ckpt_type=ckpt_type, + save_mp_checkpoint=save_mp_checkpoint) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -377,7 +388,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list) else: - 
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir) + sd_loader, _ = SDLoaderFactory.get_sd_loader_json(load_dir) if type(sd_loader) is list: self.sd = torch.load(sd_loader[0], map_location='cpu') diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index f6722deb582b..dbd263e983cf 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -7,7 +7,7 @@ from .layers import LinearAllreduce, LinearLayer, Normalize, EmbeddingLayer -def load_model_with_checkpoint(r_module, sd, mp_replace): +def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type): error_msgs = [] def transpose(data): @@ -32,33 +32,47 @@ def load(module, prefix): module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias']) def load_transformer_layer(module, prefix): - module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight']) - module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias']) - module.attention.attn_qkvw = mp_replace.copy( - module.attention.attn_qkvw.data, - transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight'])) - module.attention.attn_qkvb = mp_replace.copy( - module.attention.attn_qkvb.data, - sd[prefix + 'self_attention.query_key_value.' + 'bias']) - module.attention.attn_ow = mp_replace.copy( - module.attention.attn_ow.data, - transpose(sd[prefix + 'self_attention.dense.' + 'weight'])) - module.attention.attn_ob = mp_replace.copy( - module.attention.attn_ob.data, - sd[prefix + 'self_attention.dense.' + 'bias']) - module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + - 'weight']) - module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias']) - module.mlp.inter_w = mp_replace.copy( - module.mlp.inter_w.data, - transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight'])) - module.mlp.inter_b = mp_replace.copy(module.mlp.inter_b.data, - sd[prefix + 'mlp.dense_h_to_4h.' + 'bias']) - module.mlp.output_w = mp_replace.copy( - module.mlp.output_w.data, - transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight'])) - module.mlp.output_b = mp_replace.copy(module.mlp.output_b.data, - sd[prefix + 'mlp.dense_4h_to_h.' + 'bias']) + if ckpt_type == "mp": + + def load_parameters(module, prefix): + for n, p in module.named_parameters(): + if len(n.split('.')) == 1: + setattr(module, n, mp_replace.copy(p, sd[prefix + n])) + + load_parameters(module, prefix) + for n, child in module.named_children(): + load_parameters(child, prefix + n + '.') + else: + module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight']) + module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias']) + module.attention.attn_qkvw = mp_replace.copy( + module.attention.attn_qkvw.data, + transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight'])) + module.attention.attn_qkvb = mp_replace.copy( + module.attention.attn_qkvb.data, + sd[prefix + 'self_attention.query_key_value.' + 'bias']) + module.attention.attn_ow = mp_replace.copy( + module.attention.attn_ow.data, + transpose(sd[prefix + 'self_attention.dense.' + 'weight'])) + module.attention.attn_ob = mp_replace.copy( + module.attention.attn_ob.data, + sd[prefix + 'self_attention.dense.' + 'bias']) + module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + + 'weight']) + module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + + 'bias']) + module.mlp.inter_w = mp_replace.copy( + module.mlp.inter_w.data, + transpose(sd[prefix + 'mlp.dense_h_to_4h.' 
+ 'weight'])) + module.mlp.inter_b = mp_replace.copy( + module.mlp.inter_b.data, + sd[prefix + 'mlp.dense_h_to_4h.' + 'bias']) + module.mlp.output_w = mp_replace.copy( + module.mlp.output_w.data, + transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight'])) + module.mlp.output_b = mp_replace.copy( + module.mlp.output_b.data, + sd[prefix + 'mlp.dense_4h_to_h.' + 'bias']) layer_policies = { nn.Linear: load, @@ -98,6 +112,9 @@ def load_module_recursive(module, prefix='', level=0): dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) + elif child.__class__ is nn.Linear: + child = LinearLayer(weight=child.weight, bias=child.bias) + setattr(module, name, child) else: ds_id = None if hasattr(child.weight, 'ds_id'): @@ -110,9 +127,10 @@ def load_module_recursive(module, prefix='', level=0): layer_policies[child.__class__](child, prefix + name + '.') else: - load_module_recursive(child, - prefix if level == 0 else prefix + name + '.', - level + 1) + load_module_recursive( + child, + prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.', + level + 1) load_module_recursive(r_module) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index fe83b7b3b4cf..0cc0d8e61913 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -15,6 +15,7 @@ from ..runtime.zero import GatheredParameters from .layers import LinearAllreduce, LinearLayer, Normalize, EmbeddingLayer from .load_checkpoint import load_model_with_checkpoint +import time class ReplaceWithTensorSlicing: @@ -148,7 +149,9 @@ def replace_transformer_layer(orig_layer_impl, moe=False, moe_experts=1, moe_type='standard', - checkpoint=None): + checkpoint=None, + ckpt_type='pp', + save_mp_checkpoint=False): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -765,14 +768,39 @@ def replace_fn(child, _policy, layer_id=0): replace_fn=replace_fn, _replace_policy=policy) + start_time = time.time() if checkpoint is not None: - pbar = tqdm.tqdm(total=len(checkpoint), - desc=f"Loading {len(checkpoint)} checkpoint shards") - for i in range(len(checkpoint)): - if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: - pbar.update(1) - sd = torch.load(checkpoint[i], map_location='cpu') - load_model_with_checkpoint(replaced_module, sd, mp_replace) + rank = dist.get_rank() if dist.is_initialized() else 0 + + if ckpt_type == 'pp': + pbar = tqdm.tqdm(total=len(checkpoint), + desc=f"Loading {len(checkpoint)} checkpoint shards") + for i in range(len(checkpoint)): + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: + pbar.update(1) + sd = torch.load(checkpoint[i], map_location='cpu') + load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + else: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + num_checkpoints = len(checkpoint) // world_size + assert world_size >= len(checkpoint),\ + "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!" 
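
Note: the stride computed on the next added line groups consecutive ranks onto the same tp shard. A minimal illustration of the indexing, with assumed values (shard names and counts are not from the patch):

    # 4 model-parallel ranks sharing 2 tp checkpoint shards:
    world_size = 4
    checkpoint = ["shard_00.pt", "shard_01.pt"]        # hypothetical file names
    checkpoint_stride = world_size // len(checkpoint)  # -> 2
    # ranks 0 and 1 read shard 0; ranks 2 and 3 read shard 1
    shards = [checkpoint[rank // checkpoint_stride] for rank in range(world_size)]
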
+ checkpoint_stride = world_size // len(checkpoint) + pbar = tqdm.tqdm(total=num_checkpoints, + desc=f"Loading {num_checkpoints} checkpoint shards") + for i in range(num_checkpoints): + if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: + pbar.update(1) + sd = torch.load(checkpoint[i * world_size + (rank // checkpoint_stride)], + map_location='cpu') + load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + + if save_mp_checkpoint: + if dist.is_initialized(): + dist.barrier() + torch.save(replaced_module.state_dict(), f'/tmp/bloom-mp_0{rank}.pt') + + print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") return replaced_module diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 342f4e1b595f..5cee1b50aa0a 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -192,6 +192,7 @@ def split_tensor_along_last_dim(tensor, return tensor_list def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): + alibi = alibi.to(torch.cuda.current_device()) head_dim = hidden_size_per_partition // num_attention_heads_per_partition new_tensor_shape = mixed_x_layer.size()[:-1] + ( num_attention_heads_per_partition, diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 35ccb8d5fd02..17cf1ff73906 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -22,8 +22,11 @@ def get_sd_loader_json(json_file): sd_type = data['type'] ckpt_list = data['checkpoints'] version = data['version'] + ckpt_type = 'pp' + if 'parallelization' in data: + ckpt_type = data['parallelization'] if 'BLOOM' in sd_type or 'Bloom' in sd_type: - return ckpt_list + return ckpt_list, ckpt_type return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version) @staticmethod From b7085ea982ad4b28a9661afb5d1e2f38f85b2d6c Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 26 Jul 2022 03:58:37 +0500 Subject: [PATCH 02/15] small fix --- deepspeed/inference/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 825b9ec83877..361664a06885 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -326,7 +326,7 @@ def _apply_injection_policy(self, checkpoint_dir=None, save_mp_checkpoint=False): checkpoint, ckpt_type = SDLoaderFactory.get_sd_loader_json( - checkpoint_dir) if checkpoint_dir is not None else None + checkpoint_dir) if checkpoint_dir is not None else (None, None) replace_transformer_layer(client_module, self.module, triangular_masking=self.triangular_masking, From fa6b6ae83649bce509908ee7b33d55bd285047a4 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 26 Jul 2022 05:11:29 +0500 Subject: [PATCH 03/15] get the path for saving mp-checkpoints --- deepspeed/__init__.py | 4 ++-- deepspeed/inference/engine.py | 10 +++++----- deepspeed/module_inject/load_checkpoint.py | 2 +- deepspeed/module_inject/replace_module.py | 7 ++++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 296701e2e113..655d7a96320c 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -239,7 +239,7 @@ def init_inference(model, moe_type='standard', args=None, enable_cuda_graph=False, - save_mp_checkpoint=False): + 
save_mp_checkpoint_path=False): """Initialize the DeepSpeed InferenceEngine. Arguments: @@ -306,6 +306,6 @@ def init_inference(model, moe_type, args, enable_cuda_graph, - save_mp_checkpoint) + save_mp_checkpoint_path) return engine diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 361664a06885..d8c7a9c07658 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -52,7 +52,7 @@ def __init__(self, moe_type='standard', config=None, enable_cuda_graph=False, - save_mp_checkpoint=False): + save_mp_checkpoint_path=False): """ Args: model: torch.nn.Module @@ -132,7 +132,7 @@ def __init__(self, moe_type, training_mp_size, self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint=save_mp_checkpoint) + save_mp_checkpoint_path=save_mp_checkpoint_path) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, @@ -142,7 +142,7 @@ def __init__(self, moe_type=moe_type, training_mp_size=training_mp_size, checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None, - save_mp_checkpoint=save_mp_checkpoint) + save_mp_checkpoint_path=save_mp_checkpoint_path) device = torch.cuda.current_device() # logger.info(f"Place model to device: {device}") @@ -324,7 +324,7 @@ def _apply_injection_policy(self, moe_type='standard', training_mp_size=1, checkpoint_dir=None, - save_mp_checkpoint=False): + save_mp_checkpoint_path=False): checkpoint, ckpt_type = SDLoaderFactory.get_sd_loader_json( checkpoint_dir) if checkpoint_dir is not None else (None, None) replace_transformer_layer(client_module, @@ -351,7 +351,7 @@ def _apply_injection_policy(self, training_mp_size=training_mp_size, checkpoint=checkpoint, ckpt_type=ckpt_type, - save_mp_checkpoint=save_mp_checkpoint) + save_mp_checkpoint_path=save_mp_checkpoint_path) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index dbd263e983cf..fb0235fe3f3f 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -32,7 +32,7 @@ def load(module, prefix): module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias']) def load_transformer_layer(module, prefix): - if ckpt_type == "mp": + if ckpt_type == "tp": def load_parameters(module, prefix): for n, p in module.named_parameters(): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 0cc0d8e61913..7bc4339adce2 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -151,7 +151,7 @@ def replace_transformer_layer(orig_layer_impl, moe_type='standard', checkpoint=None, ckpt_type='pp', - save_mp_checkpoint=False): + save_mp_checkpoint_path=None): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -795,10 +795,11 @@ def replace_fn(child, _policy, layer_id=0): map_location='cpu') load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) - if save_mp_checkpoint: + if save_mp_checkpoint_path is not None: if dist.is_initialized(): dist.barrier() - torch.save(replaced_module.state_dict(), f'/tmp/bloom-mp_0{rank}.pt') + torch.save(replaced_module.state_dict(), + f'{save_mp_checkpoint_path}/bloom-tp_0{rank}.pt') print(f"checkpoint loading time at rank {rank}: 
{time.time()-start_time} sec") return replaced_module From 1ae789685887a9a05087ce5b913bd949757e6b97 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 25 Jul 2022 22:27:24 -0700 Subject: [PATCH 04/15] bug fix + formatting --- deepspeed/inference/engine.py | 2 -- deepspeed/module_inject/replace_module.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index b5841dab2916..b21cb8a8e3a9 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -13,7 +13,6 @@ from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization from ..module_inject.replace_module import replace_transformer_layer -from ..utils import logger from ..comm.comm import init_distributed from ..pipe import PipelineModule from ..moe.utils import has_moe_layers @@ -143,7 +142,6 @@ def __init__(self, save_mp_checkpoint_path=save_mp_checkpoint_path) device = torch.cuda.current_device() - # logger.info(f"Place model to device: {device}") self.module.to(device) if self.mp_world_size > 1: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index fd0adf31ec11..3b962b436f40 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -768,8 +768,8 @@ def replace_fn(child, _policy, layer_id=0): _replace_policy=policy) start_time = time.time() + rank = dist.get_rank() if dist.is_initialized() else 0 if checkpoint is not None: - rank = dist.get_rank() if dist.is_initialized() else 0 if ckpt_type == 'pp': pbar = tqdm.tqdm(total=len(checkpoint), From 070f02270de1dd7c9dfb041f10b410dfb9f30385 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 26 Jul 2022 22:33:57 +0500 Subject: [PATCH 05/15] fix save_checkpoint path --- deepspeed/__init__.py | 2 +- deepspeed/inference/engine.py | 3 +-- deepspeed/module_inject/replace_module.py | 6 +++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 655d7a96320c..50049a2a1996 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -239,7 +239,7 @@ def init_inference(model, moe_type='standard', args=None, enable_cuda_graph=False, - save_mp_checkpoint_path=False): + save_mp_checkpoint_path=None): """Initialize the DeepSpeed InferenceEngine. 
Arguments: diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index b5841dab2916..399fab3576c5 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -13,7 +13,6 @@ from ..runtime.state_dict_factory import SDLoaderFactory from ..runtime.weight_quantizer import WeightQuantization from ..module_inject.replace_module import replace_transformer_layer -from ..utils import logger from ..comm.comm import init_distributed from ..pipe import PipelineModule from ..moe.utils import has_moe_layers @@ -50,7 +49,7 @@ def __init__(self, moe_type='standard', config=None, enable_cuda_graph=False, - save_mp_checkpoint_path=False): + save_mp_checkpoint_path=None): """ Args: model: torch.nn.Module diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index fd0adf31ec11..0651901ed716 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -793,14 +793,14 @@ def replace_fn(child, _policy, layer_id=0): sd = torch.load(checkpoint[i * world_size + (rank // checkpoint_stride)], map_location='cpu') load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: if dist.is_initialized(): dist.barrier() - torch.save(replaced_module.state_dict(), - f'{save_mp_checkpoint_path}/bloom-tp_0{rank}.pt') + torch.save(replaced_module.transformer.state_dict(), + f'{save_mp_checkpoint_path}/bloom-tp_{rank:0>2d}.pt') - print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") return replaced_module From d794e6c89cbb65e1351074c21f08e0f1f4a331f5 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 28 Jul 2022 06:26:04 +0500 Subject: [PATCH 06/15] Modify checkpoint saving to include the config json used during loading --- deepspeed/inference/engine.py | 10 +-- deepspeed/module_inject/load_checkpoint.py | 34 +++++++++- deepspeed/module_inject/replace_module.py | 76 +++++++++++++++++++--- deepspeed/module_inject/replace_policy.py | 5 ++ deepspeed/runtime/state_dict_factory.py | 5 +- 5 files changed, 113 insertions(+), 17 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 21d5a4a5b4ba..c8990da31c50 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -321,8 +321,8 @@ def _apply_injection_policy(self, training_mp_size=1, checkpoint_dir=None, save_mp_checkpoint_path=False): - checkpoint, ckpt_type = SDLoaderFactory.get_sd_loader_json( - checkpoint_dir) if checkpoint_dir is not None else (None, None) + checkpoint, ckpt_type, ckpt_name, ckpt_mp_size = SDLoaderFactory.get_sd_loader_json( + checkpoint_dir) if checkpoint_dir is not None else (None, None, None) replace_transformer_layer(client_module, self.module, triangular_masking=self.triangular_masking, @@ -347,7 +347,9 @@ def _apply_injection_policy(self, training_mp_size=training_mp_size, checkpoint=checkpoint, ckpt_type=ckpt_type, - save_mp_checkpoint_path=save_mp_checkpoint_path) + save_mp_checkpoint_path=save_mp_checkpoint_path, + ckpt_name=ckpt_name, + ckpt_mp_size=ckpt_mp_size) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -387,7 +389,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list) else: - sd_loader, _ = 
SDLoaderFactory.get_sd_loader_json(load_dir) + sd_loader, _, _, _ = SDLoaderFactory.get_sd_loader_json(load_dir) if type(sd_loader) is list: self.sd = torch.load(sd_loader[0], map_location='cpu') diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index fc8100772669..e0f44675dfd7 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -2,9 +2,10 @@ import deepspeed.ops.transformer as transformer_inference from ..runtime.zero import GatheredParameters from .layers import LinearLayer, Normalize, EmbeddingLayer +import torch -def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type): +def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0): error_msgs = [] def transpose(data): @@ -34,7 +35,36 @@ def load_transformer_layer(module, prefix): def load_parameters(module, prefix): for n, p in module.named_parameters(): if len(n.split('.')) == 1: - setattr(module, n, mp_replace.copy(p, sd[prefix + n])) + src_shape = sd[prefix + n].shape + dst_shape = p.shape + + if (len(src_shape) == 2 and len(dst_shape) == 2): + if src_shape[0] == dst_shape[0] and src_shape[ + 1] == dst_shape[1]: + p.data.copy_(sd[prefix + n]) + else: + if src_shape[0] != dst_shape[0]: + weight_split = torch.split( + sd[prefix + n], + dst_shape[0], + dim=0)[rank].to( + torch.cuda.current_device()).contiguous() + else: + weight_split = torch.split( + sd[prefix + n], + dst_shape[1], + dim=1)[rank].to( + torch.cuda.current_device()).contiguous() + p.data.copy_(weight_split.contiguous()) + else: + if src_shape[0] == dst_shape[0]: + p.data.copy_(sd[prefix + n]) + else: + bias_split = torch.split( + sd[prefix + n], + dst_shape[-1])[rank].to( + torch.cuda.current_device()).contiguous() + p.data.copy_(bias_split) load_parameters(module, prefix) for n, child in module.named_children(): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 350445a94264..5a7448bfe7f3 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -118,6 +118,21 @@ def copy(self, dst, src): return torch.nn.parameter.Parameter(dst, requires_grad=False) +def get_transformer_name(replaced_module): + from .replace_policy import supported_models + from torch.nn import ModuleList + transformer_name = '' + for n, c in replaced_module.named_children(): + if c.__class__ in supported_models: + transformer_name += n + '.' 
+ for name, child in c.named_children(): + if child.__class__ is ModuleList: + transformer_name += name + break + break + return transformer_name + + def replace_transformer_layer(orig_layer_impl, model, policy=None, @@ -147,7 +162,9 @@ def replace_transformer_layer(orig_layer_impl, moe_type='standard', checkpoint=None, ckpt_type='pp', - save_mp_checkpoint_path=None): + save_mp_checkpoint_path=None, + ckpt_name=None, + ckpt_mp_size=1): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -768,8 +785,9 @@ def replace_fn(child, _policy, layer_id=0): _replace_policy=policy) start_time = time.time() - rank = dist.get_rank() if dist.is_initialized() else 0 if checkpoint is not None: + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 if ckpt_type == 'pp': pbar = tqdm.tqdm(total=len(checkpoint), @@ -780,26 +798,64 @@ def replace_fn(child, _policy, layer_id=0): sd = torch.load(checkpoint[i], map_location='cpu') load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) else: - world_size = dist.get_world_size() if dist.is_initialized() else 1 - num_checkpoints = len(checkpoint) // world_size - assert world_size >= len(checkpoint),\ + num_checkpoints = len(checkpoint) // ckpt_mp_size + assert world_size >= ckpt_mp_size,\ "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!" - checkpoint_stride = world_size // len(checkpoint) + checkpoint_stride = world_size // ckpt_mp_size pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards") for i in range(num_checkpoints): if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: pbar.update(1) - sd = torch.load(checkpoint[i * world_size + (rank // checkpoint_stride)], + sd = torch.load(checkpoint[i * ckpt_mp_size + + (rank // checkpoint_stride)], map_location='cpu') - load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type) + load_model_with_checkpoint(replaced_module, + sd, + mp_replace, + ckpt_type, + rank % (world_size // ckpt_mp_size)) print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: + from collections import OrderedDict + import json if dist.is_initialized(): dist.barrier() - torch.save(replaced_module.transformer.state_dict(), - f'{save_mp_checkpoint_path}/bloom-tp_{rank:0>2d}.pt') + transformer_name = get_transformer_name(replaced_module) + non_tp_ckpt_name = f'{save_mp_checkpoint_path}/{ckpt_name}-non-tp.pt' + ckpt_files = [non_tp_ckpt_name] * world_size + if not dist.is_initialized() or dist.get_rank() == 0: + print("Saving tp-sharded checkpoints") + torch.save( + OrderedDict({ + k: v + for k, + v in dict(replaced_module.state_dict()).items() + if transformer_name not in k + }), + non_tp_ckpt_name) + ckpt_files += [ + f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{r:0>2d}.pt' + for r in range(world_size) + ] + config = json.dumps({ + 'type': ckpt_name, + 'checkpoints': ckpt_files, + 'version': 1.0, + 'parallelization': 'tp', + 'mp_size': world_size + }) + with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", + "w") as cfg: + cfg.write(config) + torch.save( + OrderedDict({ + k: v + for k, + v in dict(replaced_module.state_dict()).items() if transformer_name in k + }), + f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt') return 
replaced_module diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index eeb6d613969b..3d5c53275e33 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -4,6 +4,8 @@ from torch.nn.parameter import Parameter from packaging import version as pkg_version +supported_models = {None} + class DSPolicy(ABC): def __init__(self, @@ -329,6 +331,9 @@ def __init__(self, client_module, inference=True): try: import transformers BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock + global supported_models + supported_models.update( + {transformers.models.bloom.modeling_bloom.BloomModel}) except: BLOOMLayerPolicy._orig_layer_class = None diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 17cf1ff73906..87850d46c9da 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -25,8 +25,11 @@ def get_sd_loader_json(json_file): ckpt_type = 'pp' if 'parallelization' in data: ckpt_type = data['parallelization'] + mp_size = 1 + if 'mp_size' in data: + mp_size = data['mp_size'] if 'BLOOM' in sd_type or 'Bloom' in sd_type: - return ckpt_list, ckpt_type + return ckpt_list, ckpt_type, sd_type, mp_size return SDLoaderFactory.get_sd_loader(ckpt_list, sd_type, version) @staticmethod From 557521d51621733e68887b72e9abb21edb1726dc Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 28 Jul 2022 06:31:13 +0500 Subject: [PATCH 07/15] set ckpt_mp_size to world_size by default --- deepspeed/inference/engine.py | 55 +++++++++++++------------ deepspeed/runtime/state_dict_factory.py | 2 +- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index c8990da31c50..5dfbe2c213f7 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -323,33 +323,34 @@ def _apply_injection_policy(self, save_mp_checkpoint_path=False): checkpoint, ckpt_type, ckpt_name, ckpt_mp_size = SDLoaderFactory.get_sd_loader_json( checkpoint_dir) if checkpoint_dir is not None else (None, None, None) - replace_transformer_layer(client_module, - self.module, - triangular_masking=self.triangular_masking, - policy=injection_policy, - mp_size=self.mp_world_size, - mp_group=self.mp_group, - ep_group=self.ep_group, - expert_mp_group=self.expert_mp_group, - config=self.config, - fp16=(self.dtype == torch.half), - training=False, - return_tuple=return_tuple, - quantize=(self.dtype == torch.int8), - quantize_settings=(self.quantization_scales, - self.quantize_merge_count, - self.mlp_extra_grouping, - self.quantize_groups), - replace_with_kernel_inject=replace_with_kernel_inject, - moe=moe, - moe_experts=moe_experts, - moe_type=moe_type, - training_mp_size=training_mp_size, - checkpoint=checkpoint, - ckpt_type=ckpt_type, - save_mp_checkpoint_path=save_mp_checkpoint_path, - ckpt_name=ckpt_name, - ckpt_mp_size=ckpt_mp_size) + replace_transformer_layer( + client_module, + self.module, + triangular_masking=self.triangular_masking, + policy=injection_policy, + mp_size=self.mp_world_size, + mp_group=self.mp_group, + ep_group=self.ep_group, + expert_mp_group=self.expert_mp_group, + config=self.config, + fp16=(self.dtype == torch.half), + training=False, + return_tuple=return_tuple, + quantize=(self.dtype == torch.int8), + quantize_settings=(self.quantization_scales, + self.quantize_merge_count, + self.mlp_extra_grouping, + self.quantize_groups), + 
replace_with_kernel_inject=replace_with_kernel_inject, + moe=moe, + moe_experts=moe_experts, + moe_type=moe_type, + training_mp_size=training_mp_size, + checkpoint=checkpoint, + ckpt_type=ckpt_type, + save_mp_checkpoint_path=save_mp_checkpoint_path, + ckpt_name=ckpt_name, + ckpt_mp_size=(ckpt_mp_size if ckpt_mp_size > 0 else self.mp_world_size)) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 87850d46c9da..80c26a476849 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -25,7 +25,7 @@ def get_sd_loader_json(json_file): ckpt_type = 'pp' if 'parallelization' in data: ckpt_type = data['parallelization'] - mp_size = 1 + mp_size = 0 if 'mp_size' in data: mp_size = data['mp_size'] if 'BLOOM' in sd_type or 'Bloom' in sd_type: From b51c44730bda5a63a07cbd48e8b43f139a36e4c7 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 28 Jul 2022 06:33:32 +0500 Subject: [PATCH 08/15] add missing None --- deepspeed/inference/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 5dfbe2c213f7..ad3d3dd9cd5c 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -322,7 +322,7 @@ def _apply_injection_policy(self, checkpoint_dir=None, save_mp_checkpoint_path=False): checkpoint, ckpt_type, ckpt_name, ckpt_mp_size = SDLoaderFactory.get_sd_loader_json( - checkpoint_dir) if checkpoint_dir is not None else (None, None, None) + checkpoint_dir) if checkpoint_dir is not None else (None, None, None, None) replace_transformer_layer( client_module, self.module, From 01003bf9b21063d2c77346df271687e962602337 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 28 Jul 2022 21:10:55 +0500 Subject: [PATCH 09/15] small fix: change None -> 0 --- deepspeed/inference/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index ad3d3dd9cd5c..269dd45ab247 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -322,7 +322,7 @@ def _apply_injection_policy(self, checkpoint_dir=None, save_mp_checkpoint_path=False): checkpoint, ckpt_type, ckpt_name, ckpt_mp_size = SDLoaderFactory.get_sd_loader_json( - checkpoint_dir) if checkpoint_dir is not None else (None, None, None, None) + checkpoint_dir) if checkpoint_dir is not None else (None, None, None, 0) replace_transformer_layer( client_module, self.module, From 043cd98e05f74fc8017f1ee89bbe0fed0291312e Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 28 Jul 2022 22:04:10 +0500 Subject: [PATCH 10/15] fix indentation --- deepspeed/module_inject/replace_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 5a7448bfe7f3..74c00a18e6ea 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -784,8 +784,8 @@ def replace_fn(child, _policy, layer_id=0): replace_fn=replace_fn, _replace_policy=policy) - start_time = time.time() if checkpoint is not None: + start_time = time.time() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 @@ -815,7 +815,7 @@ def replace_fn(child, _policy, layer_id=0): mp_replace, ckpt_type, rank % 
(world_size // ckpt_mp_size)) - print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") + print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") if save_mp_checkpoint_path is not None: from collections import OrderedDict From bee0b8fcf0c84c683a49595a9c08d0b0c7cd1d86 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 29 Jul 2022 00:44:01 +0500 Subject: [PATCH 11/15] several fixes --- deepspeed/inference/engine.py | 16 +++++++----- deepspeed/module_inject/replace_module.py | 31 +++++++++++++---------- deepspeed/runtime/state_dict_factory.py | 2 +- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index f2d9c1f7e204..9a9b16f4f10f 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -323,8 +323,12 @@ def _apply_injection_policy(self, training_mp_size=1, checkpoint_dir=None, save_mp_checkpoint_path=False): - checkpoint, ckpt_type, ckpt_name, ckpt_mp_size = SDLoaderFactory.get_sd_loader_json( - checkpoint_dir) if checkpoint_dir is not None else (None, None, None, 0) + checkpoint = SDLoaderFactory.get_sd_loader_json( + checkpoint_dir, + self.checkpoint_engine) if checkpoint_dir is not None else (None, + None, + None, + 0) replace_transformer_layer( client_module, self.module, @@ -348,11 +352,9 @@ def _apply_injection_policy(self, moe_experts=moe_experts, moe_type=moe_type, training_mp_size=training_mp_size, - checkpoint=checkpoint, - ckpt_type=ckpt_type, + checkpoint_dict=checkpoint, save_mp_checkpoint_path=save_mp_checkpoint_path, - ckpt_name=ckpt_name, - ckpt_mp_size=(ckpt_mp_size if ckpt_mp_size > 0 else self.mp_world_size)) + ) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, @@ -392,7 +394,7 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine) else: - sd_loader, _, _, _ = SDLoaderFactory.get_sd_loader_json(load_dir) + sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir) if type(sd_loader) is list: self.sd = torch.load(sd_loader[0], map_location='cpu') diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 74c00a18e6ea..91692e968f66 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -160,11 +160,8 @@ def replace_transformer_layer(orig_layer_impl, moe=False, moe_experts=1, moe_type='standard', - checkpoint=None, - ckpt_type='pp', - save_mp_checkpoint_path=None, - ckpt_name=None, - ckpt_mp_size=1): + checkpoint_dict=None, + save_mp_checkpoint_path=None): """ Replace bert-style transformer layers with DeepSpeed's transformer layer Arguments: orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, @@ -784,10 +781,14 @@ def replace_fn(child, _policy, layer_id=0): replace_fn=replace_fn, _replace_policy=policy) - if checkpoint is not None: + if checkpoint_dict is not None: start_time = time.time() rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 + checkpoint = checkpoint_dict['checkpoints'] + ckpt_type = checkpoint_dict['parallelization'] + ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) + base_dir = checkpoint_dict.get('base_dir', None) if ckpt_type == 'pp': pbar = tqdm.tqdm(total=len(checkpoint), @@ -807,9 
+808,11 @@ def replace_fn(child, _policy, layer_id=0): for i in range(num_checkpoints): if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0: pbar.update(1) - sd = torch.load(checkpoint[i * ckpt_mp_size + - (rank // checkpoint_stride)], - map_location='cpu') + + ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride) + ckpt_file = checkpoint[ + ckpt_index] if base_dir is None else f'{base_dir}/{checkpoint[ckpt_index]}' + sd = torch.load(ckpt_file, map_location='cpu') load_model_with_checkpoint(replaced_module, sd, mp_replace, @@ -820,10 +823,12 @@ def replace_fn(child, _policy, layer_id=0): if save_mp_checkpoint_path is not None: from collections import OrderedDict import json + + ckpt_name = checkpoint_dict['type'] if dist.is_initialized(): dist.barrier() transformer_name = get_transformer_name(replaced_module) - non_tp_ckpt_name = f'{save_mp_checkpoint_path}/{ckpt_name}-non-tp.pt' + non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt' ckpt_files = [non_tp_ckpt_name] * world_size if not dist.is_initialized() or dist.get_rank() == 0: print("Saving tp-sharded checkpoints") @@ -835,12 +840,10 @@ def replace_fn(child, _policy, layer_id=0): if transformer_name not in k }), non_tp_ckpt_name) - ckpt_files += [ - f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{r:0>2d}.pt' - for r in range(world_size) - ] + ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] config = json.dumps({ 'type': ckpt_name, + 'base_dir': f'{save_mp_checkpoint_path}', 'checkpoints': ckpt_files, 'version': 1.0, 'parallelization': 'tp', diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index 8751b3f515e4..cb1455254216 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -32,7 +32,7 @@ def get_sd_loader_json(json_file, checkpoint_engine): if 'mp_size' in data: mp_size = data['mp_size'] if 'BLOOM' in sd_type or 'Bloom' in sd_type: - return ckpt_list, ckpt_type, sd_type, mp_size + return data #ckpt_list, ckpt_type, sd_type, mp_size return SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine, sd_type, From 4d5a4ace924bd05a63e4465928fefacbecf92c54 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 29 Jul 2022 02:26:32 +0500 Subject: [PATCH 12/15] support checkpoint as dict or json file --- deepspeed/inference/engine.py | 4 +-- deepspeed/module_inject/replace_module.py | 8 ++++-- deepspeed/runtime/state_dict_factory.py | 34 +++++++++++------------ 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 9a9b16f4f10f..3079e68afbbc 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -232,9 +232,9 @@ def _validate_args(self, mpu): for method in methods: if not hasattr(mpu, method): raise ValueError(f"mpu is missing {method}") - if self.checkpoint is not None and not isinstance(self.checkpoint, str): + if self.checkpoint is not None and not isinstance(self.checkpoint, (str, dict)): raise ValueError( - f"checkpoint must be None or a str, got {type(self.checkpoint)}") + f"checkpoint must be None, str or dict, got {type(self.checkpoint)}") supported_dtypes = [None, torch.half, torch.int8, torch.float] if self.dtype not in supported_dtypes: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 91692e968f66..3c433bd06e2f 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -1,3 +1,4 @@ +import os import 
torch import tqdm import deepspeed @@ -788,7 +789,7 @@ def replace_fn(child, _policy, layer_id=0): checkpoint = checkpoint_dict['checkpoints'] ckpt_type = checkpoint_dict['parallelization'] ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) - base_dir = checkpoint_dict.get('base_dir', None) + base_dir = checkpoint_dict.get('base_dir', '') if ckpt_type == 'pp': pbar = tqdm.tqdm(total=len(checkpoint), @@ -810,8 +811,9 @@ def replace_fn(child, _policy, layer_id=0): pbar.update(1) ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride) - ckpt_file = checkpoint[ - ckpt_index] if base_dir is None else f'{base_dir}/{checkpoint[ckpt_index]}' + ckpt_file = os.path.join( + base_dir, + checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index] sd = torch.load(ckpt_file, map_location='cpu') load_model_with_checkpoint(replaced_module, sd, diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index cb1455254216..0b720ff471f3 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -20,23 +20,23 @@ class SDLoaderFactory: @staticmethod def get_sd_loader_json(json_file, checkpoint_engine): - with open(json_file) as f: - data = json.load(f) - sd_type = data['type'] - ckpt_list = data['checkpoints'] - version = data['version'] - ckpt_type = 'pp' - if 'parallelization' in data: - ckpt_type = data['parallelization'] - mp_size = 0 - if 'mp_size' in data: - mp_size = data['mp_size'] - if 'BLOOM' in sd_type or 'Bloom' in sd_type: - return data #ckpt_list, ckpt_type, sd_type, mp_size - return SDLoaderFactory.get_sd_loader(ckpt_list, - checkpoint_engine, - sd_type, - version) + if isinstance(json_file, str): + with open(json_file) as f: + data = json.load(f) + else: + assert isinstance(json_file, dict) + data = json_file + sd_type = data['type'] + ckpt_list = data['checkpoints'] + version = data['version'] + ckpt_type = data.get('parallelization', 'pp') + mp_size = data.get('mp_size', 0) + if 'bloom' in sd_type.lower(): + return data + return SDLoaderFactory.get_sd_loader(ckpt_list, + checkpoint_engine, + sd_type, + version) @staticmethod def get_sd_loader(ckpt_list, checkpoint_engine, sd_type='Megatron', version=None): From b8237c79186f67541fedb3fc46125f1a0f23682e Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 29 Jul 2022 04:12:05 +0500 Subject: [PATCH 13/15] fix for non-bloom models --- deepspeed/inference/engine.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 3079e68afbbc..db9efb19dcb1 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -325,10 +325,7 @@ def _apply_injection_policy(self, save_mp_checkpoint_path=False): checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, - self.checkpoint_engine) if checkpoint_dir is not None else (None, - None, - None, - 0) + self.checkpoint_engine) if checkpoint_dir is not None else None replace_transformer_layer( client_module, self.module, From 007551ea0547c8ae0c664cf03590bd711c5b99e9 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 29 Jul 2022 04:14:04 +0500 Subject: [PATCH 14/15] add default if parallelization doesn't exist --- deepspeed/module_inject/replace_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 3c433bd06e2f..5582bd863ae6 100755 --- a/deepspeed/module_inject/replace_module.py +++ 
b/deepspeed/module_inject/replace_module.py @@ -787,7 +787,7 @@ def replace_fn(child, _policy, layer_id=0): rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 checkpoint = checkpoint_dict['checkpoints'] - ckpt_type = checkpoint_dict['parallelization'] + ckpt_type = checkpoint_dict.get('parallelization', 'pp') ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size) base_dir = checkpoint_dict.get('base_dir', '') From 851a6817402815159570f1121c6e453868d343b5 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 29 Jul 2022 04:25:41 +0500 Subject: [PATCH 15/15] fix the path to save non-tp checkpoint --- deepspeed/inference/engine.py | 5 +---- deepspeed/module_inject/replace_module.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 9a9b16f4f10f..6ea4fab4148f 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -325,10 +325,7 @@ def _apply_injection_policy(self, save_mp_checkpoint_path=False): checkpoint = SDLoaderFactory.get_sd_loader_json( checkpoint_dir, - self.checkpoint_engine) if checkpoint_dir is not None else (None, - None, - None, - 0) + self.checkpoint_engine) if checkpoint_dir is not None else None replace_transformer_layer( client_module, self.module, diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 91692e968f66..9fa9e3685905 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -839,7 +839,7 @@ def replace_fn(child, _policy, layer_id=0): v in dict(replaced_module.state_dict()).items() if transformer_name not in k }), - non_tp_ckpt_name) + f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}') ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)] config = json.dumps({ 'type': ckpt_name,
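
Note: taken together, the series yields the following usage flow. This is a hypothetical driver sketch, not part of the patches: the config dict mirrors the JSON that replace_transformer_layer writes to {save_mp_checkpoint_path}/{type}_ds-inference_config.json, the model name, file names, and paths are illustrative, and arguments besides checkpoint and save_mp_checkpoint_path follow DeepSpeed's existing init_inference API.

    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom")  # illustrative

    # Mirrors the config written by the save path above (mp_size=2 assumed).
    # 'type' must contain 'bloom' (case-insensitive) for get_sd_loader_json
    # to return this dict instead of building a Megatron SD loader (patch 12).
    ds_config = {
        "type": "BLOOM",
        "base_dir": "/tmp/tp-shards",              # hypothetical directory
        "checkpoints": [
            "BLOOM-non-tp.pt", "BLOOM-non-tp.pt",  # non-tp entry repeated once per rank
            "BLOOM-tp_00.pt", "BLOOM-tp_01.pt",    # one tp shard per rank
        ],
        "version": 1.0,
        "parallelization": "tp",                   # defaults to 'pp' when absent (patch 14)
        "mp_size": 2,
    }

    engine = deepspeed.init_inference(
        model,
        mp_size=2,
        dtype=torch.half,
        replace_method="auto",
        replace_with_kernel_inject=True,
        checkpoint=ds_config,          # a dict or a path to the json file (patch 12)
        save_mp_checkpoint_path=None,  # set to a directory to (re-)write tp shards
    )

On the first run, passing the original (pp-parallelized) checkpoint json together with a save_mp_checkpoint_path directory writes the per-rank tp shards plus this config; subsequent runs pass the written config so each rank loads only its own shard, which is why the non-tp entry appears once per rank in the checkpoints list. Launch under the deepspeed launcher with --num_gpus equal to mp_size.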