1
- import pytest
2
1
import torch
3
2
from utils .llm_data import llm_models_root
4
3
from utils .util import skip_gpu_memory_less_than
@@ -238,15 +237,15 @@ def test_nemotron_h_correctness():
238
237
nemotron_h .shutdown ()
239
238
240
239
241
- @pytest .mark .skip (reason = "https://nvbugs/5404046" )
242
240
def test_nemotron_h_cuda_graph_overlap_scheduler ():
243
241
prompts = [
244
- "Tell me something I don't know about the future of AI " ,
245
- "The president of the United States is" ,
246
- "The capital of France is " ,
247
- "Hello, this is a beautiful day and I'm eager to start my day and " ,
242
+ "The sky is blue because " ,
243
+ "The sum of two and two is" ,
244
+ "The largest mammal is the " ,
245
+ "The chemical symbol for water is " ,
248
246
]
249
- sampling_config = SamplingParams (max_tokens = 12 ,
247
+
248
+ sampling_config = SamplingParams (max_tokens = 10 ,
250
249
temperature = 0.0 ,
251
250
return_generation_logits = True )
252
251
@@ -273,32 +272,46 @@ def test_nemotron_h_cuda_graph_overlap_scheduler():
273
272
prompts , sampling_params = sampling_config , use_tqdm = True )
274
273
275
274
# Verify outputs are consistent
276
- for (no_cg_no_overlap , with_cg_no_overlap ,
277
- with_cg_with_overlap ) in zip (outputs_no_cg_no_overlap ,
278
- outputs_with_cg_no_overlap ,
279
- outputs_with_cg_with_overlap ):
280
-
281
- assert (no_cg_no_overlap .outputs [0 ].text ==
282
- with_cg_no_overlap .outputs [0 ].text )
283
- assert (with_cg_no_overlap .outputs [0 ].text ==
284
- with_cg_with_overlap .outputs [0 ].text )
275
+ for i , (no_cg_no_overlap , with_cg_no_overlap ,
276
+ with_cg_with_overlap ) in enumerate (
277
+ zip (outputs_no_cg_no_overlap , outputs_with_cg_no_overlap ,
278
+ outputs_with_cg_with_overlap )):
279
+
280
+ assert (
281
+ no_cg_no_overlap .outputs [0 ].text ==
282
+ with_cg_no_overlap .outputs [0 ].text
283
+ ), f"Prompt { i } : no CG no overlap generated text != with CG no overlap generated text"
284
+ assert (
285
+ with_cg_no_overlap .outputs [0 ].text ==
286
+ with_cg_with_overlap .outputs [0 ].text
287
+ ), f"Prompt { i } : with CG no overlap generated text != with CG with overlap generated text"
285
288
286
289
# Similar to other unit tests comparing with/without CUDA graph (CG): compare the logits of the first generation step (the 2nd generated token)
287
290
torch .testing .assert_close (
288
291
no_cg_no_overlap .outputs [0 ].generation_logits [1 , :],
289
292
with_cg_no_overlap .outputs [0 ].generation_logits [1 , :],
290
293
atol = 0.2 ,
291
- rtol = 0.2 )
294
+ rtol = 0.2 ,
295
+ msg = lambda x :
296
+ f"Prompt { i } : with/without CG (no overlap) logits for first generated step { x } "
297
+ )
292
298
293
299
# compare logprobs of all generated tokens
294
- torch .testing .assert_close (extract_decode_logprobs (no_cg_no_overlap ),
295
- extract_decode_logprobs (with_cg_no_overlap ),
296
- atol = 0.2 ,
297
- rtol = 0.2 )
300
+ torch .testing .assert_close (
301
+ extract_decode_logprobs (no_cg_no_overlap ),
302
+ extract_decode_logprobs (with_cg_no_overlap ),
303
+ atol = 0.2 ,
304
+ rtol = 0.2 ,
305
+ msg = lambda x :
306
+ f"Prompt { i } : with/without CG (no overlap) logprobs for all selected tokens { x } "
307
+ )
298
308
299
309
# overlap scheduler should have no effect on all logits - low tolerance
300
310
torch .testing .assert_close (
301
311
with_cg_no_overlap .outputs [0 ].generation_logits ,
302
312
with_cg_with_overlap .outputs [0 ].generation_logits ,
303
313
atol = 0.05 ,
304
- rtol = 0.05 )
314
+ rtol = 0.05 ,
315
+ msg = lambda x :
316
+ f"Prompt { i } : with/without overlap (no CG) all generation logits { x } "
317
+ )
0 commit comments