@@ -12,20 +12,28 @@
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
 What is the capital of<|im_end|>
 <|im_start|>assistant
 """),
-    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
+    ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
 Hi there!<|im_end|>
 <|im_start|>user
-What is the capital of""")
+What is the capital of"""),
+    ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+The capital of"""),
 ]
 
 TEST_MESSAGES = [
@@ -42,6 +50,10 @@
         'content': 'What is the capital of'
     },
 ]
+ASSISTANT_MESSAGE_TO_CONTINUE = {
+    'role': 'assistant',
+    'content': 'The capital of'
+}
 
 
 def test_load_chat_template():
@@ -73,26 +85,30 @@ def test_no_load_chat_template_literallike():
 
 
 @pytest.mark.parametrize(
-    "model,template,add_generation_prompt,expected_output",
+    "model,template,add_generation_prompt,continue_final_message,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
-                        expected_output):
+                        continue_final_message, expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     template_content = load_chat_template(chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
         model=model,
-        messages=TEST_MESSAGES,
-        add_generation_prompt=add_generation_prompt)
+        messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
+        if continue_final_message else TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt,
+        continue_final_message=continue_final_message,
+    )
 
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
+        continue_final_message=mock_request.continue_final_message,
     )
 
     # Test assertion