
[bug] unnecessary batch logits post processor calls #2439

@akhoroshev

Description


Version

When I build a model with paged_context_fmha = true and max_num_tokens = 4096, chunked context is enabled. I see that the Executor calls the batch logits post-processor more than once for the first token.
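For reference, a build along these lines should reproduce the setup (paths are illustrative; the two relevant flags are real trtllm-build options):

    trtllm-build --checkpoint_dir ./ckpt \
                 --output_dir ./engine \
                 --max_num_tokens 4096 \
                 --use_paged_context_fmha enable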

To demonstrate this, I print the number of tokens inside the callback (FusedLogitsProcessor::process is my implementation of the callback).
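Roughly, the instrumentation looks like the sketch below. It assumes the non-batched tensorrt_llm::executor::LogitsPostProcessor signature from recent TensorRT-LLM releases (which may differ between versions); FusedLogitsProcessor is my own class, not part of the library.

    // Sketch only: signature assumed from the TensorRT-LLM executor API;
    // the real implementation may wrap the batched variant instead.
    #include <optional>

    #include "tensorrt_llm/common/logger.h"
    #include "tensorrt_llm/executor/executor.h"

    namespace tle = tensorrt_llm::executor;

    struct FusedLogitsProcessor
    {
        // beamTokens holds, per beam, all tokens of the sequence seen so far,
        // so beamTokens.at(0).size() is the count printed in the logs below.
        void process(tle::IdType reqId, tle::Tensor& logits, tle::BeamTokens const& beamTokens,
            tle::StreamPtr const& stream, std::optional<tle::IdType> clientId)
        {
            TLLM_LOG_ERROR("FusedLogitsProcessor::process, beamToken.size() %zu", beamTokens.at(0).size());
            // ... actual logits post-processing omitted ...
        }
    };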

I send requests with different input sizes and set maxTokens to 3.

input_context_size: 18810

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18810
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18811
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 18812

input_context_size: 15014

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15014
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15015
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 15016

input_context_size: 12585

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12585
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12586
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 12587

input_context_size: 8176

[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8176
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8176
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8177
[TensorRT-LLM][ERROR] FusedLogitsProcessor::process, beamToken.size() 8178

You can see that the first-token logits callback is invoked ceil(input_context_size / max_num_tokens) times, i.e. once per context chunk. In fact, the logits from the first ceil(input_context_size / max_num_tokens) - 1 calls are ignored (the sampling layers are not called), and the Executor returns exactly 3 tokens, as expected. But it is very strange to run a logits processor on "garbage" logits.
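A standalone check of that formula against the logged call counts (plain C++; the input sizes are copied from the logs above):

    #include <cstdio>

    int main()
    {
        constexpr int maxNumTokens = 4096;
        constexpr int inputSizes[] = {18810, 15014, 12585, 8176};
        for (int size : inputSizes)
        {
            int chunks = (size + maxNumTokens - 1) / maxNumTokens; // ceil division
            std::printf("input %5d -> %d context chunks\n", size, chunks);
        }
        return 0;
    }

This prints 5, 4, 4, and 2 chunks, matching the 5, 4, 4, and 2 equal-sized callback invocations in the logs.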
