3 changes: 3 additions & 0 deletions cpp/include/tensorrt_llm/runtime/eagleBuffers.h
@@ -67,6 +67,9 @@ class EagleBuffers
 //! [maxBatchSize, maxNumPaths, maxPathLen]
 //! or [numSequences, maxNumPaths, maxPathLen]
 TensorPtr draftPaths;
+//! [maxBatchSize, maxNumPaths, maxPathLen]
+//! or [numSequences, maxNumPaths, maxPathLen]
+TensorPtr draftPathsHost;
 //! [maxBatchSize] or [numGenSequences]
 TensorPtr specDecodingGenerationLengths;
 //! [maxBatchSize] or [numGenSequences]
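For orientation: the new draftPathsHost tensor is a host-side twin of draftPaths with the same [maxBatchSize, maxNumPaths, maxPathLen] layout, so per-request path data can be prepared on the CPU and copied into the matching device slice (see createNewDecoderRequests.cpp below). A rough sketch of how such a pair might be allocated, reusing the BufferManager-style calls that appear elsewhere in this diff; this is illustrative, not the actual EagleBuffers setup code:

// Illustrative only: the shape variables and `manager` are assumed to be in scope.
auto const pathsShape = ITensor::makeShape({maxBatchSize, maxNumPaths, maxPathLen});
// Device tensor read by the engine/kernels.
draftPaths = manager.gpu(pathsShape, nvinfer1::DataType::kINT32);
// Pinned host mirror, filled on the CPU once per new request.
draftPathsHost = manager.pinnedPool(pathsShape, nvinfer1::DataType::kINT32);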
10 changes: 5 additions & 5 deletions cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -425,25 +425,25 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
 modelConfig.getSpeculativeDecodingModulePtr());
 std::optional<executor::EagleChoices> eagleChoicesOpt;

-TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1);
-
 if (request.eagleConfig)
 {
 eagleChoicesOpt = request.eagleConfig->getEagleChoices();
 }

 if (!request.eagleConfig || !request.eagleConfig->useDynamicTree())
 {
+TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1);
+TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1);
+
 // eagleConfig is nullptr or Eagle-1
 std::vector<SizeType32> topKs;
-TensorPtr draftPathsHost = manager.pinnedPool(draftPathsSlice->getShape(), nvinfer1::DataType::kINT32);
 auto const depth = utils::initTensorsFromChoices(modelConfig.getSpeculativeDecodingModule(),
 eagleChoicesOpt.value_or(eagleModule->getDefaultEagleChoices()), topKs, nullptr, nullptr, nullptr,
-draftPathsHost, nullptr, {eagleModule->getMaxNonLeafNodesPerLayer()});
+draftPathsHostSlice, nullptr, {eagleModule->getMaxNonLeafNodesPerLayer()});
 TLLM_CHECK_WITH_INFO(depth == modelConfig.getSpeculativeDecodingModule().getMaxDraftPathLen(),
 "EAGLE-1 requires Eagle-tree depth being equal to the number of build-time EAGLE layers.");

-manager.copy(*draftPathsHost, *draftPathsSlice);
+manager.copy(*draftPathsHostSlice, *draftPathsSlice);
 }

 TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
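The behavioral change above: instead of allocating a fresh pinned buffer with manager.pinnedPool for every new request, the code now slices the preallocated draftPathsHost at batchIdx, lets initTensorsFromChoices fill that slice, and copies it to the matching device slice. A self-contained sketch of the same preallocated pinned-staging pattern written against the raw CUDA runtime (all names and sizes are made up; the real code goes through BufferManager and ITensor):

#include <cuda_runtime.h>

int main()
{
    int const maxBatchSize = 8, maxNumPaths = 4, maxPathLen = 5;
    size_t const pathsPerRequest = static_cast<size_t>(maxNumPaths) * maxPathLen;

    // Allocate once, up front: device tensor plus pinned host mirror.
    int* draftPathsDevice = nullptr;
    int* draftPathsHost = nullptr;
    cudaMalloc(&draftPathsDevice, maxBatchSize * pathsPerRequest * sizeof(int));
    cudaMallocHost(&draftPathsHost, maxBatchSize * pathsPerRequest * sizeof(int));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Per new request: fill only this request's host slice...
    int const batchIdx = 3;
    for (size_t i = 0; i < pathsPerRequest; ++i)
    {
        draftPathsHost[batchIdx * pathsPerRequest + i] = static_cast<int>(i);
    }
    // ...and copy that slice to the device; pinned memory keeps the copy asynchronous.
    cudaMemcpyAsync(draftPathsDevice + batchIdx * pathsPerRequest,
        draftPathsHost + batchIdx * pathsPerRequest, pathsPerRequest * sizeof(int),
        cudaMemcpyHostToDevice, stream);
    cudaStreamSynchronize(stream);

    cudaFreeHost(draftPathsHost);
    cudaFree(draftPathsDevice);
    cudaStreamDestroy(stream);
    return 0;
}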
20 changes: 15 additions & 5 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -2162,7 +2162,6 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu
 }
 auto const reqBeamWidth = llmReq->getBeamWidthByIter(true);
 auto const seqSlot = llmReq->mSeqSlot.value();
-auto const numGeneratedTokens = llmReq->getNumDraftTokens() + 1;
 auto const currentNumOfTokens = llmReq->getMaxBeamNumTokens();

 // Save the accepted token logits from target model
@@ -2183,13 +2182,24 @@ void TrtGptModelInflightBatching::updateRequests(ScheduledRequests const& schedu
 std::vector<SizeType32> numNewTokens(reqBeamWidth);
 std::vector<SizeType32> numDroppedTokens(reqBeamWidth);

+// numGeneratedTokens is the number of tokens generated by the decoder.
+// Some tokens might be dropped due to end token or rejected draft tokens.
+auto const numGeneratedTokens = llmReq->getNumDraftTokens() + 1;
+
 for (SizeType32 beam = 0; beam < reqBeamWidth; ++beam)
 {
+// Sequence length is only advanced for accepted tokens.
 auto const seqLen = sequenceLengthsHostData[seqSlot * mOperatingBeamWidth + beam];
-// The content of newOutputTokens might not be accurate in the case where
-// the sequence has finished early due to end token, so only add new tokens
-// to llmReq if decoder seq length is greater than current number of tokens
-numNewTokens[beam] = std::min(numGeneratedTokens, seqLen - llmReq->getNumTokens(beam));
+// Actual number of tokens that should be added to the request.
+auto const numNewOutputTokens = seqLen - llmReq->getNumTokens(beam);
+if (reqBeamWidth == 1)
+{
+TLLM_CHECK_WITH_INFO(numGeneratedTokens >= numNewOutputTokens,
+"numNewOutputTokens must not be greater than numGeneratedTokens: "
+"numGeneratedTokens %d < numNewOutputTokens %d",
+numGeneratedTokens, numNewOutputTokens);
+}
+numNewTokens[beam] = std::min(numGeneratedTokens, numNewOutputTokens);
 numDroppedTokens[beam] = numGeneratedTokens - numNewTokens[beam];
 for (SizeType32 step = 0; step < numNewTokens[beam]; ++step)
 {
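The accounting above reads most easily with concrete numbers: numGeneratedTokens is the upper bound on tokens the decoder can emit in one step (all draft tokens plus one), while the decoder sequence length only advances by the tokens that were actually accepted; the shortfall is recorded as dropped. A small standalone illustration with made-up values:

#include <algorithm>
#include <cstdio>

int main()
{
    int const numDraftTokens = 4;
    int const numGeneratedTokens = numDraftTokens + 1; // at most 5 tokens this step
    int const seqLen = 23;                             // decoder sequence length after the step
    int const numTokensInRequest = 20;                 // tokens the request already holds

    // Only accepted tokens advanced the sequence length: 23 - 20 = 3.
    int const numNewOutputTokens = seqLen - numTokensInRequest;
    int const numNewTokens = std::min(numGeneratedTokens, numNewOutputTokens); // 3 added
    int const numDroppedTokens = numGeneratedTokens - numNewTokens;            // 2 dropped
    std::printf("added %d, dropped %d\n", numNewTokens, numDroppedTokens);
    return 0;
}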
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/common.cu
@@ -138,8 +138,8 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
 TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
 FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
 T const** medusaLogits, T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep,
-SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen,
-SizeType32 maxDraftPathLen, SizeType32 maxDecodingTokens)
+SizeType32* bestPathIds, SizeType32 vocabSize, SizeType32 maxSeqLen, SizeType32 maxDraftPathLen,
+SizeType32 maxDecodingTokens)
 {
 auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
 auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
@@ -276,8 +276,8 @@ void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<T> co
 acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, params.stream>>>(params.outputIds, params.draftIds,
 params.targetIds, params.sequenceLengths, params.acceptedLengths, params.finishedFinal, params.batchSlots,
 params.paths, params.endIds, params.medusaLogits, params.logitsPtrs, params.curTokensPerStep,
-params.targetTokensPerStep, params.bestPathIds, params.batchSize, params.vocabSize, params.maxBatchSize,
-params.maxSeqLen, params.maxDraftPathLen, params.maxDecodingTokens);
+params.targetTokensPerStep, params.bestPathIds, params.vocabSize, params.maxSeqLen, params.maxDraftPathLen,
+params.maxDecodingTokens);
 }

 template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<float> const& params);
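The dropped batchSize and maxBatchSize arguments were redundant inside the kernel: the grid is sized to the batch at launch, so blockIdx.x already identifies the batch entry and batchSlots remaps it to a sequence slot, exactly as the first two lines of the kernel body show. A minimal sketch of that launch pattern (illustrative, not the real kernel):

#include <cuda_runtime.h>

__global__ void perSequenceKernel(int* out, int const* batchSlots)
{
    // One thread block per batch entry; the batch size is implied by gridDim.x.
    int const batchIdx = static_cast<int>(blockIdx.x);
    int const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
    if (threadIdx.x == 0)
    {
        out[batchSlot] = batchIdx; // per-sequence work would go here
    }
}

// Launch side: grid.x == batchSize, so the kernel needs no explicit batchSize parameter.
// perSequenceKernel<<<batchSize, 32, 0, stream>>>(out, batchSlots);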