From 4fc6c215f355ee8cd2ab44d0e5f4d43524844d83 Mon Sep 17 00:00:00 2001
From: Tian <121000916+SylarTiaNII@users.noreply.github.com>
Date: Sat, 11 May 2024 17:22:39 +0800
Subject: [PATCH] [LLM] relocate tensor_parallel_output to avoid conflict (#8419)

---
 llm/finetune_generation.py | 5 +++++
 llm/utils.py               | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index c8fed17165af..6e4123b02df2 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -16,6 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial
+from typing import Optional
 
 import paddle
 from argument import (
@@ -66,6 +67,10 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )
 
 
 def read_local_dataset(path):
diff --git a/llm/utils.py b/llm/utils.py
index 6688357bd67b..3075943877df 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -212,7 +212,7 @@ def prediction_step(
         if isinstance(logits, (list, tuple)):
             logits = logits[0]
         # all gather logits when enabling tensor_parallel_output
-        if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+        if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
             hcg = fleet.get_hybrid_communicate_group()
             model_parallel_group = hcg.get_model_parallel_group()
             gathered_logits = []
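
Reviewer note (not part of the patch): the change declares tensor_parallel_output on FinetuneArguments and switches the trainer-side check to getattr with a False default, so entry scripts whose argument class never declared the flag no longer raise AttributeError. A minimal, self-contained sketch of that behavior follows; DummyArgs, LegacyArgs, and should_gather_logits are hypothetical stand-ins, not names from the repository.

# Standalone sketch of the guarded-attribute pattern used in llm/utils.py.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DummyArgs:
    # Mirrors the field added to FinetuneArguments in this patch.
    tensor_parallel_degree: int = 1
    tensor_parallel_output: Optional[bool] = field(
        default=False,
        metadata={"help": "whether to output logits in distributed status"},
    )


@dataclass
class LegacyArgs:
    # An argument class that never declared the flag; a plain
    # args.tensor_parallel_output access would raise AttributeError here.
    tensor_parallel_degree: int = 2


def should_gather_logits(args) -> bool:
    # Same condition as the patched line: gather only when running tensor
    # parallel AND the flag is both present and truthy.
    return args.tensor_parallel_degree > 1 and getattr(args, "tensor_parallel_output", False)


if __name__ == "__main__":
    print(should_gather_logits(DummyArgs(tensor_parallel_degree=4, tensor_parallel_output=True)))  # True
    print(should_gather_logits(DummyArgs()))   # False: tensor_parallel_degree == 1
    print(should_gather_logits(LegacyArgs()))  # False: flag absent, no AttributeError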