diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index c8fed17165af..6e4123b02df2 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -16,6 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial
+from typing import Optional

 import paddle
 from argument import (
@@ -66,6 +67,10 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )


 def read_local_dataset(path):
diff --git a/llm/utils.py b/llm/utils.py
index 6688357bd67b..3075943877df 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -212,7 +212,7 @@ def prediction_step(
             if isinstance(logits, (list, tuple)):
                 logits = logits[0]
             # all gather logits when enabling tensor_parallel_output
-            if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+            if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
                 hcg = fleet.get_hybrid_communicate_group()
                 model_parallel_group = hcg.get_model_parallel_group()
                 gathered_logits = []
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index f825c308ebb8..423d77d6f510 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -787,10 +787,6 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
-    tensor_parallel_output: Optional[bool] = field(
-        default=False,
-        metadata={"help": "whether to output logits in distributed status"},
-    )
     use_expert_parallel: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
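
Note on the llm/utils.py change: the getattr guard is what keeps prediction_step usable with argument objects that never declare the flag, now that tensor_parallel_output lives on FinetuneArguments instead of TrainingArguments. A minimal sketch of that fallback behavior follows; PlainArgs and FinetuneArgs are hypothetical stand-ins for the real classes, not the actual implementation.

    from dataclasses import dataclass, field
    from typing import Optional


    @dataclass
    class PlainArgs:
        # hypothetical stand-in for TrainingArguments after the flag is removed
        tensor_parallel_degree: int = 2


    @dataclass
    class FinetuneArgs(PlainArgs):
        # hypothetical stand-in for FinetuneArguments, which now declares the flag
        tensor_parallel_output: Optional[bool] = field(default=False)


    def needs_logit_all_gather(args) -> bool:
        # mirrors the guard in prediction_step: a missing attribute falls back to False
        return args.tensor_parallel_degree > 1 and getattr(args, "tensor_parallel_output", False)


    print(needs_logit_all_gather(PlainArgs()))                                 # False: attribute absent
    print(needs_logit_all_gather(FinetuneArgs(tensor_parallel_output=True)))   # True: flag opted in

When the attribute is absent, getattr(..., False) disables the all-gather branch, so scripts that still pass a plain TrainingArguments object keep their previous behavior while FinetuneArguments users can opt in.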