@@ -29,15 +29,15 @@ async def test_model_single_request(tgi_service):
29
29
assert response .generated_text == greedy_expectations [service_name ]
30
30
31
31
# Greedy bounded with input
32
- response = await tgi_service .client .text_generation (
32
+ greedy_response = await tgi_service .client .text_generation (
33
33
"What is Deep Learning?" ,
34
34
max_new_tokens = 17 ,
35
35
return_full_text = True ,
36
36
details = True ,
37
37
decoder_input_details = True ,
38
38
)
39
- assert response .details .generated_tokens == 17
40
- assert response .generated_text == prompt + greedy_expectations [service_name ]
39
+ assert greedy_response .details .generated_tokens == 17
40
+ assert greedy_response .generated_text == prompt + greedy_expectations [service_name ]
41
41
42
42
# Sampling
43
43
response = await tgi_service .client .text_generation (
@@ -52,16 +52,12 @@ async def test_model_single_request(tgi_service):
52
52
# The response must be different
53
53
assert not response .startswith (greedy_expectations [service_name ])
54
54
55
- # Sampling with stop sequence (using one of the words returned from the previous test)
56
- stop_sequence = response .split (" " )[- 5 ]
55
+ # Greedy with stop sequence (using one of the words returned from the previous test)
56
+ stop_sequence = greedy_response . generated_text .split (" " )[- 5 ]
57
57
response = await tgi_service .client .text_generation (
58
58
"What is Deep Learning?" ,
59
- do_sample = True ,
60
- top_k = 50 ,
61
- top_p = 0.9 ,
62
- repetition_penalty = 1.2 ,
59
+ do_sample = False ,
63
60
max_new_tokens = 128 ,
64
- seed = 42 ,
65
61
stop_sequences = [stop_sequence ],
66
62
)
67
63
assert response .endswith (stop_sequence )
0 commit comments