diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 764a4e49d6..96d2818cff 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -430,7 +430,8 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> SamplingP # TODO: beam search for key in ['n', 'best_of', 'frequency_penalty', 'presence_penalty', 'seed']: - kwargs[key] = getattr(request_config, key) + if hasattr(SamplingParams, key): + kwargs[key] = getattr(request_config, key) res = SamplingParams(**kwargs) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 97ba2ed970..c4cafb865a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -19,7 +19,6 @@ from megatron.core import mpu from megatron.core.rerun_state_machine import RerunDataIterator from megatron.training import get_args, get_wandb_writer, training -from trl.trainer.grpo_trainer import nanstd from vllm.distributed import parallel_state as vllm_ps from swift.llm import RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device @@ -27,7 +26,7 @@ from swift.llm.template.template_inputs import TemplateInputs from swift.plugin import MultiTurnScheduler, multi_turns, orms from swift.trainers.rlhf_trainer.grpo_trainer import DataType -from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, aggressive_empty_cache, +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, aggressive_empty_cache, nanstd, replace_assistant_response_with_ids, set_expandable_segments) from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available, remove_response) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index fdaed1a326..36ff574cf6 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ 
b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1,5 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/trl.
+
+# fmt: off
+# apply patch before importing trl, which may internally reference GuidedDecodingParams
+try:
+    import vllm
+    try:
+        from vllm.sampling_params import GuidedDecodingParams
+    except ImportError:
+        import vllm.sampling_params
+        # GuidedDecodingParams was removed in https://github.com/vllm-project/vllm/pull/22772;
+        # alias it to its replacement only when that replacement exists, so other vllm builds do not crash here.
+        _structured_outputs_params = getattr(vllm.sampling_params, 'StructuredOutputsParams', None)
+        if _structured_outputs_params is not None:
+            vllm.sampling_params.GuidedDecodingParams = _structured_outputs_params
+except ImportError:
+    pass
+# fmt: on
+
 import concurrent.futures
 import inspect
 import os
diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py
index 0e11ce7f2f..5cb02ba587 100644
--- a/swift/trainers/rlhf_trainer/utils.py
+++ b/swift/trainers/rlhf_trainer/utils.py
@@ -52,6 +52,26 @@ def embeddings(self):
         return self.lora_embeddings
 
 
+def nanstd(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    refer: trl/trainer/utils
+    Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
+
+    Args:
+        tensor (`torch.Tensor`):
+            Input tensor of shape `(N,)`.
+
+    Returns:
+        `torch.Tensor`:
+            Standard deviation of the tensor, ignoring NaNs.
+            NaN is returned when fewer than two values are non-NaN.
+    """
+    variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True))**2)  # Compute variance ignoring NaNs
+    count = torch.sum(~torch.isnan(tensor))  # Count of non-NaN values
+    variance *= count / (count - 1)  # Bessel's correction
+    return torch.sqrt(variance)
+
+
 # code borrowed from verl/verl/utils/memory_utils.py
 def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
     """