From 6757ff9d436baa20284b411648b6029b3b377e2c Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Mon, 3 Jun 2024 20:13:07 +0800
Subject: [PATCH] [LLM] relocate tensor_parallel_output to avoid conflict
 (#8419) (#8533)

Co-authored-by: Tian <121000916+SylarTiaNII@users.noreply.github.com>
---
 llm/finetune_generation.py         | 5 +++++
 llm/utils.py                       | 2 +-
 paddlenlp/trainer/training_args.py | 4 ----
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index c8fed17165af..6e4123b02df2 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -16,6 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial
+from typing import Optional
 
 import paddle
 from argument import (
@@ -66,6 +67,10 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )
 
 
 def read_local_dataset(path):
diff --git a/llm/utils.py b/llm/utils.py
index 6688357bd67b..3075943877df 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -212,7 +212,7 @@ def prediction_step(
         if isinstance(logits, (list, tuple)):
             logits = logits[0]
         # all gather logits when enabling tensor_parallel_output
-        if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+        if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
             hcg = fleet.get_hybrid_communicate_group()
             model_parallel_group = hcg.get_model_parallel_group()
             gathered_logits = []
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index f825c308ebb8..423d77d6f510 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -787,10 +787,6 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
-    tensor_parallel_output: Optional[bool] = field(
-        default=False,
-        metadata={"help": "whether to output logits in distributed status"},
-    )
     use_expert_parallel: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},