
Commit 8c7244c

Export NaNs in logits to scheduler_stats if output is corrupted (vllm-project#18777)
Summary: Pull Request resolved: vllm-project#18777

Report the number of NaNs in logits in scheduler_stats. This can later be used to bump a Prometheus counter, but for now it is required so we can export it in our internal counter infra. This counter is used to identify bad hosts or bad GPUs that cause NaNs in logits during model forward passes. It's a common metric we expose internally.

Reviewed By: Adolfo-Karim

Differential Revision: D75423285

Signed-off-by: Vlad Mihailescu <[email protected]>
1 parent ee9a153 commit 8c7244c
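The change threads a per-request NaN count through four layers: the GPU model runner counts NaNs in each request's logits row (gated by VLLM_COMPUTE_NANS_IN_LOGITS), ModelRunnerOutput carries a req_id -> count mapping back to the scheduler, the scheduler copies each count onto its Request, and make_stats() exports the number of corrupted running requests as SchedulerStats.num_corrupted_reqs. The counting itself is a single tensor reduction; a minimal sketch in plain torch (no vLLM needed):

import torch

# Row i holds the logits the model produced for the request at batch index i.
logits = torch.tensor([
    [1.0, float("nan"), 3.0],
    [4.0, float("nan"), float("nan")],
])

# isnan() flags corrupted entries; summing over the vocab axis yields one
# NaN count per request, which is what _get_nans_in_logits reports below.
print(logits.isnan().sum(dim=-1).tolist())  # [1, 2]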

File tree

tests/v1/worker/test_gpu_model_runner.py
vllm/envs.py
vllm/v1/core/sched/scheduler.py
vllm/v1/metrics/stats.py
vllm/v1/outputs.py
vllm/v1/request.py
vllm/v1/worker/gpu_model_runner.py

7 files changed: +85 −2 lines changed

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 31 additions & 0 deletions
@@ -4,6 +4,7 @@
 import random

 import pytest
+import torch

 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -277,6 +278,36 @@ def test_update_states_request_resumed(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)


+def test_get_nans_in_logits(model_runner):
+    req_ids = ("req_0", "req_1")
+
+    scheduler_output = _schedule_new_request(*req_ids)
+    model_runner._update_states(scheduler_output)
+
+    logits = torch.tensor([
+        [1.0, 2.0, 3.0],
+        [3.0, 2.0, 1.0],
+    ], device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 0, "req_1": 0}
+
+    logits = torch.tensor([
+        [1.0, float('nan'), 3.0],
+        [4.0, float('nan'), float('nan')],
+    ],
+                          device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 1, "req_1": 2}
+
+    logits = torch.tensor([
+        [1.0, 2.0, 3.0],
+        [4.0, float('nan'), float('nan')],
+    ],
+                          device=DEVICE)
+    result = model_runner._get_nans_in_logits(logits)
+    assert result == {"req_0": 0, "req_1": 2}
+
+
 def test_update_states_no_changes(model_runner):
     req_id = "req_0"

vllm/envs.py

Lines changed: 8 additions & 1 deletion
@@ -130,6 +130,7 @@
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
     VLLM_KV_CACHE_LAYOUT: Optional[str] = None
+    VLLM_COMPUTE_NANS_IN_LOGITS: bool = False


 def get_default_cache_root():
@@ -897,7 +898,13 @@ def get_vllm_port() -> Optional[int]:
     # leave the layout choice to the backend. Mind that backends may only
     # implement and support a subset of all possible layouts.
     "VLLM_KV_CACHE_LAYOUT":
-    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
+    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None),
+
+    # Enable checking whether the generated logits contain NaNs,
+    # indicating corrupted output. Useful for debugging low-level bugs
+    # or bad hardware, but it may add compute overhead.
+    "VLLM_COMPUTE_NANS_IN_LOGITS":
+    lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))),
 }

 # --8<-- [end:env-vars-definition]
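To turn the check on, set the variable before launching vLLM; the entry above parses the value with bool(int(...)), so "0" (the default) disables the check and any non-zero integer enables it. A minimal sketch of that parsing, assuming the variable starts unset:

import os

def compute_nans_flag() -> bool:
    # Same parsing as the envs.py entry above: unset or "0" -> False,
    # "1" (or any non-zero integer string) -> True.
    return bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0")))

assert compute_nans_flag() is False   # unset -> disabled by default
os.environ["VLLM_COMPUTE_NANS_IN_LOGITS"] = "1"
assert compute_nans_flag() is True    # enabled for subsequent reads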

vllm/v1/core/sched/scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -717,6 +717,7 @@ def update_from_output(
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         pooler_outputs = model_runner_output.pooler_output
+        num_nans_in_logits = model_runner_output.num_nans_in_logits

         new_running: list[Request] = []
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
@@ -810,6 +811,10 @@
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                     req_id, new_token_ids)

+            # num_nans_in_logits comes from the model runner output.
+            if num_nans_in_logits is not None and req_id in num_nans_in_logits:
+                request.num_nans_in_logits = num_nans_in_logits[req_id]
+
             # Add newly generated spec token ids to the request.
             if spec_token_ids is not None:
                 if self.structured_output_manager.should_advance(request):
@@ -972,6 +977,8 @@ def make_stats(
             kv_cache_usage=self.kv_cache_manager.usage,
             prefix_cache_stats=prefix_cache_stats,
             spec_decoding_stats=spec_decoding_stats,
+            num_corrupted_reqs=sum(req.is_output_corrupted
+                                   for req in self.running),
         )

     def make_spec_decoding_stats(
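make_stats() reduces the per-request counts to a single gauge: every running request whose logits contained at least one NaN counts as corrupted. A self-contained sketch of that reduction, using a hypothetical _Req stand-in rather than vLLM's Request class:

from dataclasses import dataclass

@dataclass
class _Req:
    num_nans_in_logits: int = 0

    @property
    def is_output_corrupted(self) -> bool:
        # Mirrors Request.is_output_corrupted (vllm/v1/request.py).
        return self.num_nans_in_logits > 0

running = [_Req(0), _Req(2), _Req(1)]
# Summing booleans counts the True values, exactly as in make_stats().
num_corrupted_reqs = sum(req.is_output_corrupted for req in running)
assert num_corrupted_reqs == 2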

vllm/v1/metrics/stats.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ class SchedulerStats:

     spec_decoding_stats: Optional[SpecDecodingStats] = None

+    num_corrupted_reqs: int = 0
+

 @dataclass
 class LoRAStats:
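The new field defaults to 0, so existing consumers of SchedulerStats are unaffected. A hedged sketch of how an exporter might read it later (hypothetical names; the commit only adds the field, and the commit message notes the Prometheus wiring is left for a follow-up):

from dataclasses import dataclass

@dataclass
class SchedulerStatsSketch:
    # Minimal stand-in for vllm.v1.metrics.stats.SchedulerStats.
    num_corrupted_reqs: int = 0

def record(stats: SchedulerStatsSketch) -> None:
    # A real exporter would set a gauge or counter here.
    print(f"corrupted running requests: {stats.num_corrupted_reqs}")

record(SchedulerStatsSketch(num_corrupted_reqs=2))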

vllm/v1/outputs.py

Lines changed: 5 additions & 1 deletion
@@ -108,6 +108,9 @@ class ModelRunnerOutput:
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None

+    # req_id -> num_nans_in_logits
+    num_nans_in_logits: Optional[dict[str, int]] = None
+

 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               req_id_to_index={},
@@ -117,4 +120,5 @@ class ModelRunnerOutput:
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
                                               finished_sending=None,
-                                              finished_recving=None)
+                                              finished_recving=None,
+                                              num_nans_in_logits=None)

vllm/v1/request.py

Lines changed: 8 additions & 0 deletions
@@ -97,6 +97,10 @@ def __init__(
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1

+        # The number of NaNs in logits. A value greater than 0
+        # indicates that the output is corrupted.
+        self.num_nans_in_logits = 0
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         if request.mm_inputs is not None:
@@ -132,6 +136,10 @@ def append_output_token_ids(
         self._output_token_ids.extend(token_ids)
         self._all_token_ids.extend(token_ids)

+    @property
+    def is_output_corrupted(self) -> bool:
+        return self.num_nans_in_logits > 0
+
     @property
     def num_tokens(self) -> int:
         return len(self._all_token_ids)
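One semantic worth noting: update_from_output assigns the per-step count rather than accumulating it, so is_output_corrupted reflects the most recent forward pass. A request whose logits recover to NaN-free values in a later step would stop being counted in num_corrupted_reqs.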

vllm/v1/worker/gpu_model_runner.py

Lines changed: 24 additions & 0 deletions
@@ -1431,6 +1431,10 @@ def execute_model(
         )
         sampler_output.sampled_token_ids = output_token_ids

+        num_nans_in_logits = {}
+        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+            num_nans_in_logits = self._get_nans_in_logits(logits)
+
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         discard_sampled_tokens_req_indices = []
@@ -1601,6 +1605,7 @@ def execute_model(
             pooler_output=[],
             finished_sending=finished_sending,
             finished_recving=finished_recving,
+            num_nans_in_logits=num_nans_in_logits,
         )

     def kv_connector_no_forward(
@@ -1826,6 +1831,25 @@ def _get_prompt_logprobs_dict(

         return prompt_logprobs_dict

+    def _get_nans_in_logits(
+        self,
+        logits: torch.Tensor,
+    ) -> dict[str, int]:
+        try:
+            num_nans_in_logits = {}
+            num_nans_for_index = None
+            if logits is not None:
+                num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
+            for req_id in self.input_batch.req_ids:
+                req_index = self.input_batch.req_id_to_index[req_id]
+                num_nans_in_logits[req_id] = (
+                    int(num_nans_for_index[req_index])
+                    if logits is not None and num_nans_for_index is not None
+                    and req_index < logits.shape[0] else 0)
+            return num_nans_in_logits
+        except IndexError:
+            return {}
+
     @contextmanager
     def maybe_randomize_inputs(self, input_ids: torch.Tensor):
         """
