diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 8a675b9c173..611f9a2f781 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -19,7 +19,7 @@
 from transformers import AutoModelForCausalLM
 from transformers.testing_utils import torch_device
 
-from trl.extras.vllm_client import VLLMClient
+from trl.generation.vllm_client import VLLMClient
 from trl.import_utils import is_vllm_available
 from trl.scripts.vllm_serve import chunk_list
diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py
index 94d394c3f36..ebe74b0d7f2 100644
--- a/trl/experimental/gold/gold_trainer.py
+++ b/trl/experimental/gold/gold_trainer.py
@@ -48,7 +48,7 @@
 from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset, truncate_dataset
 from ...extras.profiling import profiling_decorator
-from ...extras.vllm_client import VLLMClient
+from ...generation.vllm_client import VLLMClient
 from ...import_utils import is_vllm_available
 from ...models import prepare_deepspeed
 from ...models.utils import unwrap_model_for_generation
diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py
index f136f473758..af53d455385 100644
--- a/trl/experimental/online_dpo/online_dpo_trainer.py
+++ b/trl/experimental/online_dpo/online_dpo_trainer.py
@@ -53,7 +53,7 @@
 from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ...extras.profiling import profiling_context
-from ...extras.vllm_client import VLLMClient
+from ...generation.vllm_client import VLLMClient
 from ...import_utils import is_vllm_available
 from ...models.utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
 from ...trainer.base_trainer import BaseTrainer
diff --git a/trl/extras/vllm_client.py b/trl/generation/vllm_client.py
similarity index 99%
rename from trl/extras/vllm_client.py
rename to trl/generation/vllm_client.py
index a5e3fcd874e..23893bf0816 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/generation/vllm_client.py
@@ -88,7 +88,7 @@ class VLLMClient:
     Use the client to generate completions and update model weights:
 
     ```python
-    >>> from trl.extras.vllm_client import VLLMClient
+    >>> from trl.generation.vllm_client import VLLMClient
 
     >>> client = VLLMClient()
     >>> client.generate(["Hello, AI!", "Tell me a joke"])
diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py
index 0eaa734a717..3141582a3a5 100644
--- a/trl/generation/vllm_generation.py
+++ b/trl/generation/vllm_generation.py
@@ -29,9 +29,9 @@
 from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_vllm
 from ..extras.profiling import ProfilingContext, profiling_decorator
-from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_vllm_available
 from ..trainer.utils import ensure_master_addr_port
+from .vllm_client import VLLMClient
 
 
 if TYPE_CHECKING:
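
In short, this diff moves `VLLMClient` from `trl.extras.vllm_client` into the `trl.generation` package and updates every internal import to match; `vllm_generation.py`, which already lives in that package, switches to a sibling-relative import. A minimal migration sketch for downstream callers, assuming the client's API is otherwise unchanged and that a vLLM server is already running for the client to connect to (e.g. started via `trl vllm-serve`, as in the module's docstring):

```python
# Hypothetical downstream migration sketch; only the import path changes.

# Before this change:
# from trl.extras.vllm_client import VLLMClient

# After this change:
from trl.generation.vllm_client import VLLMClient

# Usage mirrors the docstring example in the renamed module
# (assumes a running vLLM server that the client connects to at init).
client = VLLMClient()
client.generate(["Hello, AI!", "Tell me a joke"])
```

Note that `vllm_client.py` itself sits at the same depth under `trl/` before and after the rename, so its own relative `from ..import_utils import ...` still resolves to `trl.import_utils` and needs no change.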