Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions trl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

import importlib.resources as resources
import logging
import os
import sys

import torch
from accelerate import logging
from accelerate.commands.launch import launch_command, launch_command_parser

from .scripts.dpo import make_parser as make_dpo_parser
Expand All @@ -32,7 +31,7 @@
from .scripts.vllm_serve import make_parser as make_vllm_serve_parser


logger = logging.get_logger(__name__)
logger = logging.getLogger(__name__)


def main():
Expand Down Expand Up @@ -144,17 +143,6 @@ def main():

elif args.command == "vllm-serve":
(script_args,) = parser.parse_args_and_config()

# Known issue: Using DeepSpeed with tensor_parallel_size=1 and data_parallel_size>1 may cause a crash when
# launched via the CLI. Suggest running the module directly.
# More information: https://github.com/vllm-project/vllm/issues/17079
if script_args.tensor_parallel_size == 1 and script_args.data_parallel_size > 1 and torch.cuda.is_available():
logger.warning(
"Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup is known to "
"cause a crash when using the `trl vllm-serve` CLI entry point. As a workaround, please run the "
"server using the module path instead: `python -m trl.scripts.vllm_serve`",
)

vllm_serve_main(script_args)


Expand Down
Loading