diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
index 3b1d313b6ea2b..a6e8661f7a6ea 100755
--- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
+++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
@@ -657,21 +657,42 @@ def save_state_dict(self, path):
         if self._topo.get_coord(self.global_rank).data != 0:
             return
 
-        def _offset_dirname(ckpt_dir, local_layer_idx):
-            idx = local_layer_idx + self._start_pos
+        def _offset_dirname(ckpt_dir, local_layer_idx, local_chunk_id=None):
+            if self._num_virtual_pipeline_stages == 1:
+                pos_offset = self._start_pos
+            else:
+                assert hasattr(self, '_start_poss')
+                assert local_chunk_id < len(self._start_poss)
+                pos_offset = self._start_poss[local_chunk_id]
+            idx = local_layer_idx + pos_offset
             model_rank = self._topo.get_coord(self.global_rank).model
             rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
+            virtual_pipeline_stage_message = ""
+            if self._num_virtual_pipeline_stages > 1:
+                # add virtual pipeline info to the save path
+                assert local_chunk_id is not None
+                virtual_pipeline_stage_message = "-virtual_pp_stage_{:0>2d}".format(
+                    local_chunk_id)
             layer_save_path = os.path.join(ckpt_dir,
                                            'layer_{:0>2d}'.format(idx))
-            layer_save_path = layer_save_path + rank_message + '-model_states.pdparams'
+            layer_save_path = layer_save_path + virtual_pipeline_stage_message + rank_message + '-model_states.pdparams'
             return layer_save_path
 
+        def _save_model(run_functions, local_chunk_id=None):
+            for idx, layer in enumerate(run_functions):
+                model_save_path = _offset_dirname(path, idx, local_chunk_id)
+                if not hasattr(layer, 'state_dict'):
+                    continue
+                paddle.save(layer.state_dict(), model_save_path)
+
         os.makedirs(path, exist_ok=True)
-        for idx, layer in enumerate(self.run_function):
-            model_save_path = _offset_dirname(path, idx)
-            if not hasattr(layer, 'state_dict'):
-                continue
-            paddle.save(layer.state_dict(), model_save_path)
+        if self._num_virtual_pipeline_stages > 1:
+            logger.info("save model state for virtual pipeline stage...")
+            for chunk_id in range(len(self._model_chunks)):
+                run_function = self._model_chunks[chunk_id].get_run_function()
+                _save_model(run_function, chunk_id)
+        else:
+            _save_model(self.run_function)
 
         logger.info("save model state successfully...")
 
@@ -679,21 +700,43 @@ def set_state_dir(self, path):
         assert os.path.exists(
             path), "{} not found, please check the path".format(path)
 
-        for idx, layer in enumerate(self.run_function):
-            if not hasattr(layer, 'set_state_dict'):
-                continue
-            layer_idx = idx + self._start_pos
-            layer_save_path = os.path.join(path,
-                                           'layer_{0:0>2d}'.format(layer_idx))
-            model_files = glob.glob(layer_save_path + "*model_states.pdparams")
-            model_files.sort()
-            mp_rank = self._topo.get_coord(self.global_rank).model
-            mp_world_size = self._topo.get_dim('model')
-            num_files = len(model_files)
-
-            load_param_path = model_files[mp_rank * num_files // mp_world_size]
-            model_state_dict = paddle.load(load_param_path)
-            layer.set_state_dict(model_state_dict)
+        def _load_model(run_functions, local_chunk_id=None):
+            for idx, layer in enumerate(run_functions):
+                if not hasattr(layer, 'set_state_dict'):
+                    continue
+                if self._num_virtual_pipeline_stages == 1:
+                    pos_offset = self._start_pos
+                else:
+                    assert hasattr(self, '_start_poss')
+                    assert local_chunk_id < len(self._start_poss)
+                    pos_offset = self._start_poss[local_chunk_id]
+                layer_idx = idx + pos_offset
+                layer_save_path = os.path.join(
+                    path, 'layer_{0:0>2d}'.format(layer_idx))
+                if self._num_virtual_pipeline_stages > 1:
+                    # add virtual pipeline info to the path
+                    assert local_chunk_id is not None
+                    layer_save_path = layer_save_path + "-virtual_pp_stage_{:0>2d}".format(
+                        local_chunk_id)
+                model_files = glob.glob(layer_save_path +
+                                        "*model_states.pdparams")
+                model_files.sort()
+                mp_rank = self._topo.get_coord(self.global_rank).model
+                mp_world_size = self._topo.get_dim('model')
+                num_files = len(model_files)
+
+                load_param_path = model_files[mp_rank * num_files //
+                                              mp_world_size]
+                model_state_dict = paddle.load(load_param_path)
+                layer.set_state_dict(model_state_dict)
+
+        if self._num_virtual_pipeline_stages > 1:
+            logger.info("load model state for virtual pipeline stage...")
+            for chunk_id in range(len(self._model_chunks)):
+                run_function = self._model_chunks[chunk_id].get_run_function()
+                _load_model(run_function, chunk_id)
+        else:
+            _load_model(self.run_function)
 
         self._synchronize_shared_weights()
         logger.info("load model state successfully...")
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_save_load_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_save_load_with_virtual_stage.py
new file mode 100644
index 0000000000000..6569a6ef0a13d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_save_load_with_virtual_stage.py
@@ -0,0 +1,117 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+import random
+import os
+import shutil
+import tempfile
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from hybrid_parallel_pp_transformer_with_virtual_stage import ModelPipe, set_random_seed
+
+batch_size = 8
+length = 8
+micro_batch_size = 2
+vocab_size = 128
+
+
+class TestDistPPSaveLoadTraining(unittest.TestCase):
+
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 1
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        world_size = hcg.get_model_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        topology = hcg.topology()
+        set_random_seed(1024, dp_id, rank_id)
+
+        model = ModelPipe(topology)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[2],
+                                                       values=[0.001, 0.002],
+                                                       verbose=True)
+        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
+                                         parameters=model.parameters())
+
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        output_dir = tempfile.mkdtemp()
+
+        # warmup steps
+        for step_id in range(2):
+            x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = True
+            loss = model.train_batch([x, x], optimizer, scheduler)
+
+        model._layers.save_state_dict(output_dir)
+        paddle.save(optimizer.state_dict(),
+                    os.path.join(output_dir, "model_state.pdopt"))
+
+        # construct data
+        test_steps = 5
+        np_data = np.random.randint(0,
+                                    vocab_size,
+                                    size=[test_steps, batch_size, length])
+
+        origin_loss = []
+        for step_id in range(test_steps):
+            x_data = np_data[step_id, :]
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = True
+            loss = model.train_batch([x, x], optimizer, scheduler)
+            origin_loss.append(loss.numpy())
+
+        # test step
+        model._layers.set_state_dir(output_dir)
+        opt_dict = paddle.load(os.path.join(output_dir, "model_state.pdopt"))
+        optimizer.set_state_dict(opt_dict)
+
+        for step_id in range(test_steps):
+            x_data = np_data[step_id, :]
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = True
+            loss = model.train_batch([x, x], optimizer, scheduler)
+            print("origin loss: ", origin_loss[step_id], "current loss: ",
+                  loss.numpy())
+            np.testing.assert_allclose(loss.numpy(), origin_loss[step_id])
+
+        # finally, remove the model/optimizer path
+        shutil.rmtree(output_dir)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer.py
index ffe4a063a9ccf..1e13404e69de2 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_pp_transformer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
index 2330e0dac6cac..643aba4450bcc 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_pipeline_parallel_with_virtual_stage.py
@@ -30,6 +30,10 @@ def test_hybrid_parallel_pp_transformer_with_virtual_stage(self):
         self.run_mnist_2gpu(
             'hybrid_parallel_pp_transformer_with_virtual_stage.py')
 
+    def test_hybrid_parallel_save_load_with_virtual_stage(self):
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_save_load_with_virtual_stage.py')
+
 
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"
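
Note on the checkpoint layout introduced by the pp_layers.py change: with interleaved scheduling, each pipeline rank saves one *-model_states.pdparams file per layer of every model chunk it owns, and the chunk id is embedded in the file name between the layer index and the model-parallel (tensor) rank; set_state_dir later globs for the same layer/stage prefix when reloading. The snippet below is a standalone sketch of that naming scheme, not part of the patch: example_layer_ckpt_name is a hypothetical helper, while the format strings are copied from _offset_dirname above.

# Standalone illustration, not part of the patch: mirrors the checkpoint
# file naming built by _offset_dirname when virtual pipeline stages are used.
# The helper name is hypothetical; format strings are taken from the hunk above.
import os


def example_layer_ckpt_name(ckpt_dir, layer_idx, model_rank,
                            virtual_pp_stage=None):
    # "-virtual_pp_stage_XX" appears only when more than one virtual stage
    # is configured (i.e. _num_virtual_pipeline_stages > 1).
    stage_message = ""
    if virtual_pp_stage is not None:
        stage_message = "-virtual_pp_stage_{:0>2d}".format(virtual_pp_stage)
    rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
    return os.path.join(
        ckpt_dir, 'layer_{:0>2d}'.format(layer_idx) + stage_message +
        rank_message + '-model_states.pdparams')


# layer 3 of virtual stage 1 on model-parallel rank 0:
#   ./ckpt/layer_03-virtual_pp_stage_01-tensor_00-model_states.pdparams
print(example_layer_ckpt_name("./ckpt", 3, 0, virtual_pp_stage=1))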