Merged
9 changes: 9 additions & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -273,6 +273,15 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
.def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
.def_property_readonly("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState)
.def_property_readonly("stage", &GenLlmReq::getRequestStage)
.def_property_readonly("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS)
.def_property_readonly("kv_cache_size", &GenLlmReq::getKvCacheSize)
.def_property_readonly("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter)
.def_property_readonly("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest)
.def_property_readonly("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest)
.def_property_readonly("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest)
.def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
.def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
.def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
.def_property_readonly("position_ids",
[](GenLlmReq& self)
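For illustration only (not part of this diff), a minimal Python sketch of how the newly bound read-only properties might be read off a request object once these bindings are built; `req` is a placeholder for any bound LlmRequest instance:

```python
# Sketch: `req` is assumed to be a pybind LlmRequest exposing the
# read-only properties added above; none of this code is in the PR.
def collect_kv_cache_stats(req) -> dict:
    return {
        "stage": req.stage,
        "kv_cache_transfer_time_ms": req.kv_cache_transfer_time_ms,
        "kv_cache_size": req.kv_cache_size,
        "avg_decoded_tokens_per_iter": req.avg_decoded_tokens_per_iter,
        "alloc_total_blocks": req.alloc_total_blocks,
        "alloc_new_blocks": req.alloc_new_blocks,
        "reused_blocks": req.reused_blocks,
        "missed_blocks": req.missed_blocks,
        "kv_cache_hit_rate": req.kv_cache_hit_rate,
    }
```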
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -61,6 +61,9 @@ class PyTorchConfig:
kv_cache_dtype: str = "auto"
use_kv_cache: bool = True
enable_iter_perf_stats: bool = False
# If true, enables per-request stats per iteration
# Must also set enable_iter_perf_stats to true to get request stats
enable_iter_req_stats: bool = False
print_iter_log: bool = False

torch_compile_enabled: bool = False
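As the new comment notes, `enable_iter_req_stats` only has an effect when `enable_iter_perf_stats` is also enabled. A hedged sketch of turning both on; the field names come from this diff, but constructing `PyTorchConfig` directly like this is an assumption about how the config is used:

```python
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Sketch: per-request stats require the iteration-level stats flag as well.
pytorch_config = PyTorchConfig(
    enable_iter_perf_stats=True,  # prerequisite for any iteration stats
    enable_iter_req_stats=True,   # adds per-request stats to each iteration entry
)
```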
99 changes: 87 additions & 12 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -20,10 +20,13 @@

from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range,
trace_func)
from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
RequestStage, RequestStats,
RequestType, StaticBatchingStats)
from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
ReqIdsSet)
from tensorrt_llm.logger import logger

from ..distributed import Distributed
@@ -196,6 +199,7 @@ def __init__(self,
self.max_draft_tokens = max_draft_tokens
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
self.num_fetch_requests_cur_rank = 0
self.num_fetch_requests = 0
self.shutdown_event = threading.Event()
@@ -373,10 +377,10 @@ def get_latest_iteration_stats(self):
if self.enable_iter_perf_stats == False:
return []

latest_stats = tuple()
latest_stats = (IterationStats(), None)
try:
self.stats_lock.acquire()
latest_stats = tuple(self.stats)
latest_stats = self.stats
self.stats = []
finally:
self.stats_lock.release()
@@ -510,8 +514,63 @@ def _get_init_iter_stats(self, num_new_active_requests,
stats.static_batching_stats = StaticBatchingStats()
return stats

def _populate_req_stats(
self, finished_requests: List[LlmRequest],
active_requests: List[LlmRequest],
scheduled_requests: ScheduledRequests
) -> Optional[List[RequestStats]]:

def get_req_stats(req: LlmRequest) -> RequestStats:
req_stat = RequestStats()
req_stat.id = req.request_id
req_stat.context_prefill_position = req.context_current_position
req_stat.num_generated_tokens = req.max_beam_num_tokens - req.orig_prompt_len
req_stat.avg_num_decoded_tokens_per_iter = req.avg_decoded_tokens_per_iter
req_stat.alloc_total_blocks_per_request = req.alloc_total_blocks
req_stat.alloc_new_blocks_per_request = req.alloc_new_blocks
req_stat.reused_blocks_per_request = req.reused_blocks
req_stat.missed_blocks_per_request = req.missed_blocks
req_stat.kv_cache_hit_rate_per_request = req.kv_cache_hit_rate
req_stat.scheduled = req in scheduled_requests.context_requests or req in scheduled_requests.generation_requests
if req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY or req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
req_stat.dis_serving_stats = DisServingRequestStats()
req_stat.dis_serving_stats.kv_cache_transfer_ms = req.kv_cache_transfer_time_ms
req_stat.dis_serving_stats.kv_cache_size = req.kv_cache_size
return req_stat

def get_queued_req_stats(req: LlmRequest) -> RequestStats:
req_stat = RequestStats()
req_stat.id = req.request_id
req_stat.context_prefill_position = 0
req_stat.num_generated_tokens = 0
req_stat.avg_num_decoded_tokens_per_iter = 0
req_stat.alloc_total_blocks_per_request = 0
req_stat.alloc_new_blocks_per_request = 0
req_stat.reused_blocks_per_request = 0
req_stat.missed_blocks_per_request = 0
req_stat.kv_cache_hit_rate_per_request = 0
return req_stat

req_stats = []
for req in active_requests:
req_stat = get_req_stats(req)
req_stat.stage = req.stage
req_stats.append(req_stat)

for req in list(self.request_queue.queue):
Collaborator review comment:
@pcastonguay I found a bug with this function: the input to get_queued_req_stats is a tuple, not an LlmRequest, since requests are enqueued with self.request_queue.put((self.next_req_id, request)).

req_stat = get_queued_req_stats(req)
req.stage = RequestStage.QUEUED
req_stats.append(req_stat)

for req in finished_requests:
req_stat = get_req_stats(req)
req_stat.stage = RequestStage.GENERATION_COMPLETE
req_stats.append(req_stat)

return req_stats
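Following up on the review comment above: a hedged sketch of how the queued-request loop might unpack the `(req_id, request)` tuples that `self.request_queue.put((self.next_req_id, request))` enqueues. This only illustrates the reviewer's point, not the fix adopted in this PR; it mirrors the zero-valued fields of `get_queued_req_stats` but takes the id from the tuple, and assumes the stage is meant to be set on the stat object as in the other two loops:

```python
# Illustration only; assumes each queue entry is a (req_id, request) tuple,
# as the review comment notes, so the id comes from the tuple rather than
# from a request_id attribute on the raw queued request.
for queued_req_id, _queued_request in list(self.request_queue.queue):
    req_stat = RequestStats()
    req_stat.id = queued_req_id
    req_stat.context_prefill_position = 0
    req_stat.num_generated_tokens = 0
    req_stat.avg_num_decoded_tokens_per_iter = 0
    req_stat.alloc_total_blocks_per_request = 0
    req_stat.alloc_new_blocks_per_request = 0
    req_stat.reused_blocks_per_request = 0
    req_stat.missed_blocks_per_request = 0
    req_stat.kv_cache_hit_rate_per_request = 0
    req_stat.stage = RequestStage.QUEUED
    req_stats.append(req_stat)
```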

def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
scheduled_batch):
scheduled_batch) -> IterationStats:
stats.iter_latency_ms = iter_latency_ms

stats.num_queued_requests = self.request_queue.qsize()
@@ -554,23 +613,34 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
stats.inflight_batching_stats.micro_batch_id = 0
return stats

def _append_iter_stats(self, stats):
def _append_iter_stats(self,
stats: IterationStats,
req_stats: Optional[List[RequestStats]] = None):

try:
self.stats_lock.acquire()
self.stats.append(stats)
self.stats.append((stats, req_stats))
finally:
self.stats_lock.release()

def _process_iter_stats(self, finished_requests: list[LlmRequest],
active_requests: List[LlmRequest],
batch_state: BatchState):
iter_end_time = time.time()
iter_latency_ms = iter_end_time - batch_state.iter_start_time
if batch_state.iter_stats is None:
return

req_stats = self._populate_req_stats(
finished_requests, active_requests,
batch_state.decoder_state.scheduled_requests) if (
self.enable_iter_req_stats
and self.enable_iter_perf_stats) else None

self._append_iter_stats(
self._update_iter_stats(
batch_state.iter_stats, iter_latency_ms, len(finished_requests),
batch_state.decoder_state.scheduled_requests))
batch_state.decoder_state.scheduled_requests), req_stats)

def _executor_loop_cleanup(self):
with self.response_cv:
@@ -677,7 +747,9 @@ def _executor_loop_pp(self):
self._gather_dp_requests_num()

if self.enable_iter_perf_stats and previous_batch is not None:
self._process_iter_stats(finished_requests, previous_batch)
self._process_iter_stats(finished_requests,
self.active_requests,
previous_batch)
self._executor_loop_cleanup()

def _executor_loop_pp_overlap(self):
Expand Down Expand Up @@ -815,7 +887,9 @@ def _executor_loop_pp_overlap(self):
self._gather_dp_requests_num()

if self.enable_iter_perf_stats and previous_batch is not None:
self._process_iter_stats(finished_requests, previous_batch)
self._process_iter_stats(finished_requests,
self.active_requests,
previous_batch)
self._executor_loop_cleanup()

def _executor_loop(self):
Expand Down Expand Up @@ -921,7 +995,7 @@ def _executor_loop(self):
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
'num_ctx_tokens']
self._process_iter_stats(
finished_requests,
finished_requests, self.active_requests,
BatchState(decoder_state=DecoderState(
scheduled_requests=scheduled_batch),
iter_stats=iter_stats,
@@ -1099,7 +1173,8 @@ def _process_previous_batch(self):
self._add_kv_cache_events()

if self.enable_iter_perf_stats:
self._process_iter_stats(finished_requests, self.previous_batch)
self._process_iter_stats(finished_requests, self.active_requests,
self.previous_batch)

@nvtx_range("_forward_step_inter_pp")
def _forward_step_inter_pp(self, scheduled_batch) -> DecoderState:
31 changes: 28 additions & 3 deletions tensorrt_llm/executor/worker.py
@@ -272,9 +272,34 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue,
return True # success

def dispatch_stats_task(self) -> bool:
return self._iteration_result_task(
self.stats_queues, self.engine.get_latest_iteration_stats,
self._iter_stats_result, lambda x: x.to_json_str())

# Define a Callable to join iteration and request stats
def stats_serializer(
stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
iteration_stats, req_stats = stats
stats_dict = json.loads(iteration_stats.to_json_str())

if req_stats is not None and len(req_stats) > 0:
stats_dict["requestStats"] = []
for req_stat in req_stats:
stats_dict["requestStats"].append(
json.loads(req_stat.to_json_str()))

# Convert back to JSON string
return json.dumps(stats_dict)

def get_stats():
if isinstance(self.engine, tllm.Executor):
iter_stats = self.engine.get_latest_iteration_stats()
#TODO: Support req stats with TRT engine
# This would require ensuring iter and req stats have same size
return [(iter_stat, None) for iter_stat in iter_stats]
else:
return self.engine.get_latest_iteration_stats()

return self._iteration_result_task(self.stats_queues, get_stats,
self._iter_stats_result,
stats_serializer)
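
For reference, a small sketch of how a consumer might read the merged payload produced by stats_serializer. The top-level keys come from IterationStats.to_json_str(), the "requestStats" list is the one added above, and the per-entry key names (e.g. "scheduled") are assumptions based on the RequestStats fields set in _populate_req_stats:

```python
import json

# Sketch: count how many of the reported requests were scheduled this iteration.
# Assumes each "requestStats" entry serializes its `scheduled` field under that
# key; adjust if the executor uses a different naming scheme.
def count_scheduled(serialized_stats: str) -> int:
    stats = json.loads(serialized_stats)
    return sum(1 for r in stats.get("requestStats", []) if r.get("scheduled"))
```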

def dispatch_kv_cache_events_task(self) -> bool:
if isinstance(self.engine, tllm.Executor):