Commit a9d76c8

Revert "[None][feat] Optimize CUDA graph memory usage for spec decode cases (#6718)"
This reverts commit 8df7a26.
1 parent 8861b56 · commit a9d76c8

2 files changed: 4 additions & 12 deletions

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 5 deletions
@@ -726,11 +726,8 @@ def disable_optimization(backend: Backend):
         # For non-draft model, we also capture the CUDA graph instance for draft length 0,
         # so that when we disable spec decode at runtime, we can still run the captured graph.
         # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-        if (not self.is_draft_model and self.max_draft_len > 0
-                and not self.spec_config.spec_dec_mode.use_one_engine()
-                # Assume that speculation is always on if the user didn't give us a max_concurrency
-                # value. This will save on memory.
-                and self.spec_config.max_concurrency is not None):
+        if not self.is_draft_model and self.max_draft_len > 0 and not self.spec_config.spec_dec_mode.use_one_engine(
+        ):
             draft_lengths.append(0)
 
         for bs in cuda_graph_batch_sizes:
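
The change above restores the pre-#6718 condition: the draft-length-0 CUDA graph is captured whenever spec decode can be disabled at runtime (any mode other than one-engine), rather than only when the user set max_concurrency. Below is a minimal, hypothetical sketch of that restored condition; the SpecConfig stand-in, the collect_draft_lengths wrapper, and the initial contents of draft_lengths are assumptions for illustration, and only the predicate mirrors the diff.

# Hypothetical sketch; names mirror the diff, scaffolding is assumed.
from dataclasses import dataclass
from typing import List


@dataclass
class SpecConfig:
    one_engine: bool = False

    def use_one_engine(self) -> bool:
        return self.one_engine


def collect_draft_lengths(is_draft_model: bool, max_draft_len: int,
                          spec_config: SpecConfig) -> List[int]:
    draft_lengths = [max_draft_len]  # assumed starting point
    # Restored behavior: for the target (non-draft) model, also capture a
    # CUDA graph for draft length 0 whenever spec decode can be disabled at
    # runtime, i.e. in any mode other than one-engine. The reverted commit
    # additionally required spec_config.max_concurrency to be set, which
    # skipped this graph (and saved memory) when speculation was always on.
    if (not is_draft_model and max_draft_len > 0
            and not spec_config.use_one_engine()):
        draft_lengths.append(0)
    return draft_lengths


# With draft length 3 in two-engine mode, graphs are captured for [3, 0];
# in one-engine mode spec decode cannot be turned off, so only [3].
assert collect_draft_lengths(False, 3, SpecConfig()) == [3, 0]
assert collect_draft_lengths(False, 3, SpecConfig(one_engine=True)) == [3]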

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 2 additions & 7 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, final
+from typing import List, Optional
 
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
@@ -26,13 +26,8 @@ def prepare_draft_tokens(
         """
         raise NotImplementedError
 
-    @final
     def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
-        """
-        You probably don't want to override this. ModelEngine
-        assumes that speculation is always on if max_concurrency
-        is not specified by the user's spec config.
-        """
+        """Check if spec decode should be used for the current iteration."""
         if self.max_concurrency is not None:
             return len(requests) <= self.max_concurrency
         return True
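
The restored should_use_spec_decode gates speculation purely on batch size: with a max_concurrency cap, speculate only while the active batch fits under it; with no cap, speculation is always on. A self-contained sketch under assumed scaffolding follows; the Drafter constructor here is hypothetical, and only the method body tracks the post-revert diff.

# Self-contained sketch; constructor is assumed, method body follows the diff.
from typing import List, Optional


class Drafter:
    def __init__(self, max_concurrency: Optional[int] = None):
        self.max_concurrency = max_concurrency

    def should_use_spec_decode(self, requests: List[object]) -> bool:
        """Check if spec decode should be used for the current iteration."""
        # Speculate only while the active batch fits under the configured
        # concurrency cap; with no cap, speculation is always on.
        if self.max_concurrency is not None:
            return len(requests) <= self.max_concurrency
        return True


# With a cap of 4, a batch of 8 requests falls back to normal decoding.
drafter = Drafter(max_concurrency=4)
assert drafter.should_use_spec_decode([object()] * 2)
assert not drafter.should_use_spec_decode([object()] * 8)

Note that the revert also drops the @final decorator, so subclasses are again free to override this check.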
