Skip to content

Commit 6d5da9f

Browse files
authored
[https://nvbugs/5404046][fix] Fix Nemotron-H flaky CUDA graph / overlap scheduler test (#6485)
Signed-off-by: Tomer Asida <[email protected]>
1 parent 0c42f54 commit 6d5da9f

File tree

1 file changed

+35
-22
lines changed

1 file changed

+35
-22
lines changed

tests/unittest/_torch/modeling/test_modeling_nemotron_h.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import torch
32
from utils.llm_data import llm_models_root
43
from utils.util import skip_gpu_memory_less_than
@@ -238,15 +237,15 @@ def test_nemotron_h_correctness():
238237
nemotron_h.shutdown()
239238

240239

241-
@pytest.mark.skip(reason="https://nvbugs/5404046")
242240
def test_nemotron_h_cuda_graph_overlap_scheduler():
243241
prompts = [
244-
"Tell me something I don't know about the future of AI",
245-
"The president of the United States is",
246-
"The capital of France is",
247-
"Hello, this is a beautiful day and I'm eager to start my day and",
242+
"The sky is blue because",
243+
"The sum of two and two is",
244+
"The largest mammal is the",
245+
"The chemical symbol for water is",
248246
]
249-
sampling_config = SamplingParams(max_tokens=12,
247+
248+
sampling_config = SamplingParams(max_tokens=10,
250249
temperature=0.0,
251250
return_generation_logits=True)
252251

@@ -273,32 +272,46 @@ def test_nemotron_h_cuda_graph_overlap_scheduler():
273272
prompts, sampling_params=sampling_config, use_tqdm=True)
274273

275274
# Verify outputs are consistent
276-
for (no_cg_no_overlap, with_cg_no_overlap,
277-
with_cg_with_overlap) in zip(outputs_no_cg_no_overlap,
278-
outputs_with_cg_no_overlap,
279-
outputs_with_cg_with_overlap):
280-
281-
assert (no_cg_no_overlap.outputs[0].text ==
282-
with_cg_no_overlap.outputs[0].text)
283-
assert (with_cg_no_overlap.outputs[0].text ==
284-
with_cg_with_overlap.outputs[0].text)
275+
for i, (no_cg_no_overlap, with_cg_no_overlap,
276+
with_cg_with_overlap) in enumerate(
277+
zip(outputs_no_cg_no_overlap, outputs_with_cg_no_overlap,
278+
outputs_with_cg_with_overlap)):
279+
280+
assert (
281+
no_cg_no_overlap.outputs[0].text ==
282+
with_cg_no_overlap.outputs[0].text
283+
), f"Prompt {i}: no CG no overlap generated text != with CG no overlap generated text"
284+
assert (
285+
with_cg_no_overlap.outputs[0].text ==
286+
with_cg_with_overlap.outputs[0].text
287+
), f"Prompt {i}: with CG no overlap generated text != with CG with overlap generated text"
285288

286289
# similar to other unittests comparing with / without CG, compare logits of first generation step (2nd generated token)
287290
torch.testing.assert_close(
288291
no_cg_no_overlap.outputs[0].generation_logits[1, :],
289292
with_cg_no_overlap.outputs[0].generation_logits[1, :],
290293
atol=0.2,
291-
rtol=0.2)
294+
rtol=0.2,
295+
msg=lambda x:
296+
f"Prompt {i}: with/without CG (no overlap) logits for first generated step {x}"
297+
)
292298

293299
# compare logprobs of all generated tokens
294-
torch.testing.assert_close(extract_decode_logprobs(no_cg_no_overlap),
295-
extract_decode_logprobs(with_cg_no_overlap),
296-
atol=0.2,
297-
rtol=0.2)
300+
torch.testing.assert_close(
301+
extract_decode_logprobs(no_cg_no_overlap),
302+
extract_decode_logprobs(with_cg_no_overlap),
303+
atol=0.2,
304+
rtol=0.2,
305+
msg=lambda x:
306+
f"Prompt {i}: with/without CG (no overlap) logprobs for all selected tokens {x}"
307+
)
298308

299309
# overlap scheduler should have no effect on all logits - low tolerance
300310
torch.testing.assert_close(
301311
with_cg_no_overlap.outputs[0].generation_logits,
302312
with_cg_with_overlap.outputs[0].generation_logits,
303313
atol=0.05,
304-
rtol=0.05)
314+
rtol=0.05,
315+
msg=lambda x:
316+
f"Prompt {i}: with/without overlap (with CG) all generation logits {x}"
317+
)

0 commit comments

Comments
 (0)