Add merge tp trainer (PaddlePaddle#5491)
* Add merge tp option in the trainer.

* Fix bugs.

* Add warning.

* Add set hybrid parallel seed.

* tmp

* Fix the default value for Bloom.

* Fix the prompt trainer save.

---------

Co-authored-by: Zhong Hui <[email protected]>
wawltor and ZHUI authored Apr 3, 2023
1 parent 7e8e2ab commit 68b21f1
Showing 5 changed files with 67 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/language_model/glm/finetune_generation.py
@@ -182,7 +182,7 @@ def compute_metrics(eval_preds):

if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()
trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
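
For context, a minimal sketch of how the merged checkpoint produced by the call above would be reloaded afterwards. This is not part of the commit: AutoModel is used as a generic stand-in for the concrete GLM model class, and the output directory is a placeholder.

# Reload sketch: because save_model(merge_tensor_parallel=True) merges the
# tensor-parallel shards into a single weight file, the resulting checkpoint
# can be loaded on one device with no distributed setup.
from paddlenlp.transformers import AutoModel

model = AutoModel.from_pretrained("./glm_output")  # placeholder path
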
9 changes: 7 additions & 2 deletions paddlenlp/prompt/prompt_trainer.py
@@ -130,8 +130,13 @@ def encode_with_template(example):
def _prepare_input(self, inputs: Dict):
return inputs

def _save(self, output_dir: Optional[str] = None, state_dict: Dict[str, Any] = None):
super(PromptTrainer, self)._save(output_dir, state_dict)
def _save(
self,
output_dir: Optional[str] = None,
state_dict: Dict[str, Any] = None,
merge_tensor_parallel: Optional[bool] = True,
):
super(PromptTrainer, self)._save(output_dir, state_dict, merge_tensor_parallel)
output_dir = output_dir if output_dir is not None else self.args.output_dir
if self.template:
self.template.save(output_dir)
40 changes: 27 additions & 13 deletions paddlenlp/trainer/trainer.py
@@ -77,6 +77,7 @@
get_last_checkpoint,
get_scheduler,
has_length,
set_hyrbid_parallel_seed,
set_seed,
speed_metrics,
)
@@ -212,6 +213,13 @@ def __init__(

# Seed must be set before instantiating the model when using model
set_seed(self.args.seed)
if self.args.use_hybrid_parallel:
set_hyrbid_parallel_seed(
basic_seed=self.args.seed,
dataset_rank=self.args.dataset_rank,
tp_rank=self.args.tensor_parallel_rank,
)

if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")

@@ -527,8 +535,15 @@ def train(
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Total num train samples = {num_train_samples}")
logger.info(
f" Number of trainable parameters = {sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) }"
f" Number of trainable parameters = {sum(p.numel().item() for p in model.parameters() if not p.stop_gradient)} (per device)"
)
if self.args.use_hybrid_parallel and self.args.tensor_parallel_degree > 1:
# todo fix for pipeline_parallel_degree
logger.info(
" Number of trainable parameters = "
f"{sum(p.numel().item() for p in model.parameters() if not p.stop_gradient) * self.args.tensor_parallel_degree}"
" (all devices, roughly)"
)

start_time = time.time()
self._globalstep_last_start_time = time.time()
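
A back-of-the-envelope illustration of the "(all devices, roughly)" figure logged above, with invented numbers. The estimate is rough because parameters that are replicated rather than sharded across tensor-parallel ranks (layer norms, for example) are counted once per rank.

# Illustrative arithmetic only; the numbers are made up.
per_device_trainable = 1_750_000_000   # sum of p.numel() over trainable params on one TP rank
tensor_parallel_degree = 4

rough_all_devices = per_device_trainable * tensor_parallel_degree
print(f"{rough_all_devices:,}")        # 7,000,000,000 -- overcounts replicated parameters
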
@@ -1367,7 +1382,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

return loss.detach()

def save_model(self, output_dir: Optional[str] = None):
def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
"""
Will save the model, so you can reload it using `from_pretrained()`.
@@ -1378,7 +1393,7 @@ def save_model(self, output_dir: Optional[str] = None):
output_dir = self.args.output_dir

if self.args.should_save:
self._save(output_dir)
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
@@ -1506,32 +1521,31 @@ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel

if not isinstance(self.model, PretrainedModel) and not isinstance(self.model, LoRAModel):
if isinstance(unwrap_model(self.model), PretrainedModel):

# unwrap_model(self.model).save_pretrained(
# output_dir, state_dict=state_dict)
if self.args.use_hybrid_parallel:
unwrap_model(self.model).resource_files_names["model_state"] = _add_variant(
WEIGHTS_NAME, self.args.weight_name_suffix
)
unwrap_model(self.model).save_pretrained(output_dir)
unwrap_model(self.model).save_pretrained(output_dir, merge_tensor_parallel=merge_tensor_parallel)
else:
logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.")
if merge_tensor_parallel:
logger.warning("Trainer.model is not a `PretrainedModel`, not suppor for merge_tensor_parallel.")
if state_dict is None:
state_dict = self.model.state_dict()
paddle.save(
state_dict, os.path.join(output_dir, _add_variant(WEIGHTS_NAME, self.args.weight_name_suffix))
)
else:
self.model.save_pretrained(output_dir)
self.model.save_pretrained(output_dir, merge_tensor_parallel=merge_tensor_parallel)

if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

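
A condensed sketch of the save path introduced above, for orientation only. The isinstance checks, LoRAModel handling, model unwrapping, and the _add_variant(WEIGHTS_NAME, ...) file naming of the real _save are simplified away; the hasattr check is a stand-in for the PretrainedModel test and the file name is the PaddleNLP default weight name.

import os
import paddle

def sketch_save(model, args, output_dir, merge_tensor_parallel=False):
    # Merging shards only makes sense when hybrid (tensor) parallelism is active.
    merge_tensor_parallel = merge_tensor_parallel and args.use_hybrid_parallel
    if hasattr(model, "save_pretrained"):
        # PretrainedModel path: either merge the TP shards into one weight file
        # or write this rank's shard with a rank-specific suffix.
        model.save_pretrained(output_dir, merge_tensor_parallel=merge_tensor_parallel)
    else:
        # Plain nn.Layer path: merging is not supported, only the state dict is saved.
        if merge_tensor_parallel:
            print("merge_tensor_parallel ignored: model is not a PretrainedModel")
        paddle.save(model.state_dict(), os.path.join(output_dir, "model_state.pdparams"))
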
15 changes: 15 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -880,3 +880,18 @@ def _remove_columns(self, feature: dict) -> dict:
def __call__(self, features: List[dict]):
features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)


def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

random.seed(basic_seed + dataset_rank)
np.random.seed(basic_seed + dataset_rank)
paddle.seed(basic_seed + dataset_rank)

# local_seed/ global_seed is used to control dropout in ModelParallel
local_seed = basic_seed + 123 + tp_rank * 10 + pp_rank * 1000
global_seed = basic_seed + dataset_rank
tracker = get_rng_state_tracker()
tracker.add("global_seed", global_seed)
tracker.add("local_seed", local_seed)
35 changes: 17 additions & 18 deletions paddlenlp/trainer/training_args.py
@@ -627,9 +627,6 @@ def __post_init__(self):
if len(self.sharding) == 0 and self.sharding_parallel_degree > 0:
warnings.warn("`--sharding_parallel_degree` is useful only when `--sharding` is specified.")

if self.tensor_parallel_degree <= 1:
self.tensor_parallel_degree = 1

if len(self.sharding) > 0 or self.tensor_parallel_degree > 1:
self.use_hybrid_parallel = True

@@ -639,14 +636,20 @@ def __post_init__(self):
world_size % self.tensor_parallel_degree == 0
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree:{self.tensor_parallel_degree}."

tensor_parallel_degree = max(self.tensor_parallel_degree, 1)

if self.sharding_parallel_degree == -1:
self.sharding_parallel_degree = world_size // self.tensor_parallel_degree
if len(self.sharding) > 0:
self.sharding_parallel_degree = world_size // self.tensor_parallel_degree

sharding_parallel_degree = max(self.sharding_parallel_degree, 1)

assert world_size % (self.sharding_parallel_degree * self.tensor_parallel_degree) == 0, (
assert world_size % (sharding_parallel_degree * tensor_parallel_degree) == 0, (
"The world size for workers should be divided by sharding_parallel_degree and tensor_parallel_degree, "
"sharding_parallel_degree:{sharding_parallel_degree}, tensor_parallel_degree:{tensor_parallel_degree}, world_size:{self.world_size}"
"sharding_parallel_degree:{sharding_parallel_degree}, tensor_parallel_degree:{tensor_parallel_degree},"
" world_size:{world_size}"
)
self.data_parallel_degree = world_size // (self.sharding_parallel_degree * self.tensor_parallel_degree)
self.data_parallel_degree = world_size // (sharding_parallel_degree * tensor_parallel_degree)

if ShardingOption.OFFLOAD in self.sharding or ShardingOption.FULL_SHARD in self.sharding:
warnings.warn("`offload` and `stage3` is not supported NOW!")
@@ -656,9 +659,9 @@ def __post_init__(self):
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_degree,
"mp_degree": self.tensor_parallel_degree,
"mp_degree": tensor_parallel_degree,
"pp_degree": 1,
"sharding_degree": self.sharding_parallel_degree,
"sharding_degree": sharding_parallel_degree,
}
fleet.init(is_collective=True, strategy=strategy)
logger.info(strategy)
@@ -751,14 +754,14 @@ def data_parallel_rank(self):
@property
def dataset_rank(self):
if self.use_hybrid_parallel:
return self.sharding_parallel_degree * self.data_parallel_rank + self.sharding_parallel_rank
return max(self.sharding_parallel_degree, 1) * self.data_parallel_rank + self.sharding_parallel_rank
else:
return paddle.distributed.get_rank()

@property
def dataset_world_size(self):
if self.use_hybrid_parallel:
return self.sharding_parallel_degree * self.data_parallel_degree
return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
else:
return paddle.distributed.get_world_size()

@@ -767,20 +770,16 @@ def sharding_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
if sharding_group.rank < 0:
return 0
return sharding_group.rank
return max(sharding_group.rank, 0)
else:
return 0

@property
def tensor_parallel_rank(self):
if self.use_hybrid_parallel:
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
if mp_group.rank < 0:
return 0
return mp_group.rank
tp_group = hcg.get_model_parallel_group()
return max(tp_group.rank, 0)
else:
return 0

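
A worked example of the degree and rank bookkeeping above, with invented numbers: 8 workers, tensor parallelism of degree 4, sharding disabled. Because sharding_parallel_degree keeps its -1 default when sharding is off, the max(..., 1) clamps are what keep the divisibility check and the data-parallel degree well defined.

# Worked example in plain Python; the values are illustrative.
world_size = 8
tensor_parallel_degree = 4      # user-specified
sharding_parallel_degree = -1   # default; unchanged because sharding is disabled

tp = max(tensor_parallel_degree, 1)        # 4
sd = max(sharding_parallel_degree, 1)      # 1
assert world_size % (sd * tp) == 0
data_parallel_degree = world_size // (sd * tp)   # 2

# dataset_rank / dataset_world_size as computed by the properties above,
# for an example worker with data_parallel_rank=1 and sharding_parallel_rank=0:
dataset_rank = sd * 1 + 0                               # 1
dataset_world_size = sd * max(data_parallel_degree, 1)  # 2
print(data_parallel_degree, dataset_rank, dataset_world_size)
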
