@@ -351,15 +351,16 @@ def generate(
 
     def chat(
         self,
-        messages: List[ChatCompletionMessageParam],
+        conversations: Union[List[ChatCompletionMessageParam],
+                             List[List[ChatCompletionMessageParam]]],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
         tools: Optional[List[Dict[str, Any]]] = None,
-    ) -> List[RequestOutput]:
+    ) -> Union[List[List[RequestOutput]], List[RequestOutput]]:
         """
         Generate responses for a chat conversation.
 
@@ -371,8 +372,9 @@ def chat(
         to the OpenAI API.
 
         Args:
-            messages: A single conversation represented as a list of messages.
-                Each message is a dictionary with 'role' and 'content' keys.
+            conversations: A list of conversations or a single conversation.
+                Each conversation is represented as a list of messages, and
+                each message is a dictionary with 'role' and 'content' keys.
             sampling_params: The sampling parameters for text generation.
                 If None, we use the default sampling parameters. When it
                 is a single value, it is applied to every prompt. When it
@@ -386,49 +388,66 @@ def chat(
                 to each message.
 
         Returns:
-            A list of ``RequestOutput`` objects containing the generated
-            responses in the same order as the input messages.
+            A list of lists or a single list of ``RequestOutput`` objects
+            containing the generated responses in the same order as the input
+            conversations and messages.
         """
+        list_of_conversations: List[List[ChatCompletionMessageParam]]
 
-        tokenizer = self.get_tokenizer()
-        model_config = self.llm_engine.get_model_config()
-
-        conversation, mm_data = parse_chat_messages(messages, model_config,
-                                                    tokenizer)
-
-        prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
-            prompt = apply_mistral_chat_template(
-                tokenizer,
-                messages=messages,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+        # Handle both a batch of conversations and a single conversation.
+        is_batch = is_list_of(conversations, list)
+        if is_batch:
+            list_of_conversations = conversations
         else:
-            prompt = apply_hf_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=chat_template,
-                add_generation_prompt=add_generation_prompt,
-                tools=tools,
-            )
+            # conversations is a single List[...] of messages.
+            list_of_conversations = [conversations]
+
+        outputs = []
+
+        for messages in list_of_conversations:
+            tokenizer = self.get_tokenizer()
+            model_config = self.llm_engine.get_model_config()
+
+            conversation, mm_data = parse_chat_messages(
+                messages, model_config, tokenizer)
+
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=messages,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=chat_template,
+                    add_generation_prompt=add_generation_prompt,
+                    tools=tools,
+                )
+
+            inputs: PromptInputs
+            if is_list_of(prompt, int):
+                inputs = TokensPrompt(prompt_token_ids=prompt)
+            else:
+                inputs = TextPrompt(prompt=prompt)
 
-        inputs: PromptInputs
-        if is_list_of(prompt, int):
-            inputs = TokensPrompt(prompt_token_ids=prompt)
-        else:
-            inputs = TextPrompt(prompt=prompt)
+            if mm_data is not None:
+                inputs["multi_modal_data"] = mm_data
 
-        if mm_data is not None:
-            inputs["multi_modal_data"] = mm_data
+            out = self.generate(
+                inputs,
+                sampling_params=sampling_params,
+                use_tqdm=use_tqdm,
+                lora_request=lora_request,
+            )
+            outputs.append(out)
 
-        return self.generate(
-            inputs,
-            sampling_params=sampling_params,
-            use_tqdm=use_tqdm,
-            lora_request=lora_request,
-        )
+        # Unwrap only when a single conversation (not a batch) was passed in.
+        return outputs if is_batch else outputs[0]
 
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
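
For reference, a minimal usage sketch of the batched chat() API shown in the
diff above; the model name and prompts are illustrative placeholders, not part
of the patch:

# Usage sketch only: assumes a chat-capable model is available locally.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder model
params = SamplingParams(temperature=0.8, max_tokens=64)

# A single conversation behaves as before: returns List[RequestOutput].
single = [{"role": "user", "content": "Name the capital of France."}]
out = llm.chat(single, sampling_params=params)
print(out[0].outputs[0].text)

# A batch of conversations returns List[List[RequestOutput]], one inner
# list per conversation, in input order.
batch = [
    [{"role": "user", "content": "Name the capital of France."}],
    [{"role": "user", "content": "Write a haiku about the sea."}],
]
for conv_out in llm.chat(batch, sampling_params=params):
    print(conv_out[0].outputs[0].text)

Overloading the return shape on the input shape keeps single-conversation
calls backward compatible, at the cost of callers having to branch on the
result type when the input shape is not known statically.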