Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion fastchat/serve/multi_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ def create_multi_model_worker():
action="append",
help="One or more model names. Values must be aligned with `--model-path` values.",
)
parser.add_argument(
"--conv-template",
type=str,
default=None,
action="append",
help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.",
)
parser.add_argument("--limit-worker-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")
Expand All @@ -201,9 +208,16 @@ def create_multi_model_worker():
if args.model_names is None:
args.model_names = [[x.split("/")[-1]] for x in args.model_path]

if args.conv_template is None:
args.conv_template = [None] * len(args.model_path)
elif len(args.conv_template) == 1: # Repeat the same template
args.conv_template = args.conv_template * len(args.model_path)

# Launch all workers
workers = []
for model_path, model_names in zip(args.model_path, args.model_names):
for conv_template, model_path, model_names in zip(
args.conv_template, args.model_path, args.model_names
):
w = ModelWorker(
args.controller_address,
args.worker_address,
Expand All @@ -219,6 +233,7 @@ def create_multi_model_worker():
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
stream_interval=args.stream_interval,
conv_template=conv_template,
)
workers.append(w)
for model_name in model_names:
Expand Down