Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions nemo_skills/pipeline/utils/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,11 @@ def configure_client(
- server_address: Address of the server.
- extra_arguments: Updated extra arguments for the command.
"""
# Check if user already specified server.server_type in extra_arguments
server_type_arg = "" if "++server.server_type=" in extra_arguments else f"++server.server_type={server_type} "
# Only add server_type if user didn't specify it (allows vllm_multimodal override)
extra_arguments = server_type_arg + extra_arguments

if server_gpus: # we need to host the model
server_port = get_free_port(strategy="random") if get_random_port else 5000
assert server_gpus is not None, "Need to specify server_gpus if hosting the model"
Expand All @@ -567,13 +572,9 @@ def configure_client(
if server_container:
server_config["container"] = server_container
extra_arguments = (
f"{extra_arguments} ++server.server_type={server_type} ++server.host=127.0.0.1 "
f"++server.port={server_port} ++server.model={model} "
f"++server.host=127.0.0.1 ++server.port={server_port} ++server.model={model} {extra_arguments}"
)
else: # model is hosted elsewhere
server_config = None
extra_arguments = (
f"{extra_arguments} ++server.server_type={server_type} "
f"++server.base_url={server_address} ++server.model={model} "
)
extra_arguments = f"++server.base_url={server_address} ++server.model={model} {extra_arguments}"
return server_config, server_address, extra_arguments