
Commit 95899dd

re-formatted
1 parent 18ae428 commit 95899dd

File tree

1 file changed: +60 -41 lines changed

vllm/entrypoints/llm.py

Lines changed: 60 additions & 41 deletions
@@ -351,15 +351,16 @@ def generate(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        conversations: Union[List[ChatCompletionMessageParam],
+                             List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
         tools: Optional[List[Dict[str, Any]]] = None,
-    ) -> List[RequestOutput]:
+    ) -> Union[List[List[RequestOutput]], List[RequestOutput]]:
         """
         Generate responses for a chat conversation.
 
@@ -371,8 +372,9 @@ def chat(
         to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            conversations: A list or a single conversation represented as a list
+                of messages. Each message is a dictionary with 'role' and
+                'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -386,49 +388,66 @@ def chat(
             to each message.
 
         Returns:
-            A list of ``RequestOutput`` objects containing the generated
-            responses in the same order as the input messages.
+            A list of lists or single list of ``RequestOutput`` objects
+            containing the generated responses in the same order as the input
+            conversations and messages.
         """
+        list_of_conversations: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle multi and single conversations
+        if is_list_of(conversations, list):
+            # conversations is List[List[...]]
+            list_of_conversations = conversations
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # conversations is List[...]
+            list_of_conversations = [conversations]
+
+        outputs = []
+
+        for messages in list_of_conversations:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                messages, model_config, tokenizer)
+
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=messages,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            inputs: PromptInputs
+            if is_list_of(prompt, int):
+                inputs = TokensPrompt(prompt_token_ids=prompt)
+            else:
+                inputs = TextPrompt(prompt=prompt)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                inputs["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            out = self.generate(
+                inputs,
+                sampling_params=sampling_params,
+                use_tqdm=use_tqdm,
+                lora_request=lora_request,
+            )
+            outputs.append(out)
 
-        return self.generate(
-            inputs,
-            sampling_params=sampling_params,
-            use_tqdm=use_tqdm,
-            lora_request=lora_request,
-        )
+        # When conversations is List[...], return a single list.
+        return outputs if len(outputs) > 1 else outputs[0]
 
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
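For context, a minimal usage sketch of the signature introduced by this commit: `chat` now accepts either a single conversation or a list of conversations, and the return shape follows the input. The model name, messages, and sampling settings below are illustrative placeholders, not part of the commit.

```python
# Sketch of calling LLM.chat() with the signature from this commit.
# Model name, messages, and sampling settings are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
params = SamplingParams(temperature=0.7, max_tokens=64)

# Single conversation (List[ChatCompletionMessageParam])
# -> returns List[RequestOutput].
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what vLLM does in one sentence."},
]
single_out = llm.chat(conversation, sampling_params=params)
print(single_out[0].outputs[0].text)

# List of conversations (List[List[ChatCompletionMessageParam]])
# -> returns List[List[RequestOutput]], in the same order as the input.
conversations = [
    conversation,
    [{"role": "user", "content": "Name three prime numbers."}],
]
batched_out = llm.chat(conversations, sampling_params=params)
for per_conversation in batched_out:
    print(per_conversation[0].outputs[0].text)
```

Note that, per the final `return outputs if len(outputs) > 1 else outputs[0]` in the diff, a batch containing exactly one conversation is flattened to a single `List[RequestOutput]` just like the single-conversation case, so callers passing `List[List[...]]` should be prepared for either return shape.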
