Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom conversation template in multi_model_worker #2434

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion fastchat/serve/multi_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def create_multi_model_worker():
action="append",
help="One or more model names. Values must be aligned with `--model-path` values.",
)
parser.add_argument(
"--conv-template",
type=str,
default=None,
action="append",
help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.",
)
parser.add_argument("--limit-worker-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
Expand All @@ -201,9 +208,16 @@ def create_multi_model_worker():
if args.model_names is None:
args.model_names = [[x.split("/")[-1]] for x in args.model_path]

if args.conv_template is None:
args.conv_template = [None] * len(args.model_path)
elif len(args.conv_template) == 1: # Repeat the same template
args.conv_template = args.conv_template * len(args.model_path)

# Launch all workers
workers = []
for model_path, model_names in zip(args.model_path, args.model_names):
for conv_template, model_path, model_names in zip(
args.conv_template, args.model_path, args.model_names
):
w = ModelWorker(
args.controller_address,
args.worker_address,
Expand All @@ -219,6 +233,7 @@ def create_multi_model_worker():
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
stream_interval=args.stream_interval,
conv_template=conv_template,
)
workers.append(w)
for model_name in model_names:
Expand Down