Skip to content

[None][feat] move kv cache measure into transfer session #6633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 8, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ void CacheFormatter::format(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};

if (connections.size() > 1)
Expand Down Expand Up @@ -712,7 +712,7 @@ void CacheFormatter::unformat(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};
if (pickUpConnections.size() > 1)
{
Expand Down
9 changes: 0 additions & 9 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,6 @@ class BaseCacheFormatter

/// @brief Destructor.
virtual ~BaseCacheFormatter() = default;

// TODO: better way for context/generation tagging
void markAsSender(bool isSender)
{
kvCacheMeasureHelper.markAsSender(isSender);
}

protected:
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
};

// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
Expand Down
37 changes: 37 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo)
return totalSize;
}

/// @brief Record one transfer measurement on this session.
///
/// Does nothing unless measurement recording was enabled for the session.
///
/// @param delay    Delay in milliseconds, stored unchanged.
/// @param duration Transfer duration in milliseconds.
/// @param size     Number of bytes transferred; only used to derive the bandwidth.
void TransferSession::appendMeasure(double delay, double duration, size_t size)
{
    if (mRecordMeasure)
    {
        // bytes -> bits, milliseconds -> seconds, then scale down to Gbps.
        auto const durationSec = duration / 1000;
        auto const gbps = size * 8 / durationSec / 1e9;
        mMeasures.emplace_back(Measure{delay, duration, gbps});
    }
}

/// @brief Append this session's recorded measures to a CSV stream as a single row.
///
/// A header line is written first when the stream's write position is still zero.
/// The row starts with a request id followed by delay, duration and bandwidth for
/// every recorded measure, and the stream is flushed afterwards.
///
/// @param outFile   Destination stream.
/// @param isContext When true the row is labelled with the request's own id; otherwise
///                  with the id taken from the request's context phase params (which
///                  must then be present).
void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const
{
    // Nothing recorded for this session: no row to emit.
    if (mMeasures.empty())
    {
        return;
    }

    // A write position of zero means the file is still empty, so emit the header once.
    bool const needsHeader = (outFile.tellp() == 0);
    if (needsHeader)
    {
        outFile << "RequestID";
        // One column triple per recorded measure of this session.
        for (auto remaining = mMeasures.size(); remaining > 0; --remaining)
        {
            outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
        }
        outFile << '\n';
    }

    // The non-context path reads the id from the context phase params, so they must exist.
    TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value());
    auto rowId = isContext ? mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId();

    outFile << rowId;
    for (auto const& entry : mMeasures)
    {
        outFile << "," << entry.delay;
        outFile << "," << entry.duration;
        outFile << "," << entry.bandwidth;
    }
    outFile << '\n';
    outFile.flush();
}

class DataResponder::Impl
{
public:
Expand Down
92 changes: 16 additions & 76 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,23 @@ class RequestInfo
class TransferSession
{
public:
struct Measure
{
double delay; // from last token (ctx) or arrival time (gen), in ms
double duration; // in ms
double bandwidth; // in Gbps
};

TransferSession(std::vector<Connection const*> connections, DataContext dataContext,
executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState,
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr)
runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false)
: mConnections(std::move(connections))
, mDataContext(dataContext)
, mSelfState(&selfState)
, mOtherState(std::move(otherState))
, mBufferManager(&bufferManager)
, mRequest(llmRequest)
, mRecordMeasure(recordMeasure)
{
TLLM_CHECK(!mConnections.empty());
}
Expand Down Expand Up @@ -163,13 +171,20 @@ class TransferSession
mRequest = &llmRequest;
}

void appendMeasure(double delay, double duration, size_t size);

// TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file
void exportMeasure(std::ofstream& outFile, bool isContext) const;

private:
std::vector<Connection const*> mConnections;
DataContext mDataContext;
executor::DataTransceiverState const* mSelfState; // stored in DataRequester/DataResponder
executor::DataTransceiverState mOtherState;
runtime::BufferManager const* mBufferManager;
LlmRequest const* mRequest;
bool mRecordMeasure;
std::vector<Measure> mMeasures;
};

// Operators required for data transmission in specific communication protocols.
Expand Down Expand Up @@ -266,79 +281,4 @@ class DataRequester
std::unique_ptr<Impl> mImpl;
};

/// Accumulates KV cache transfer measurements keyed by request id and writes them
/// to a per-rank CSV file when the helper is destroyed.
class KvCacheMeasureHelper
{
public:
    /// A single timed transfer.
    struct Measure
    {
        double delay;     // from last token (ctx) or arrival time (gen), in ms
        double duration;  // in ms
        double bandwidth; // in Gbps
    };

    /// @param output_path Prefix used to build the CSV file name; an empty value disables recording.
    KvCacheMeasureHelper(std::string output_path)
        : mOutputPath(std::move(output_path))
    {
    }

    /// Select the "send" or "recv" suffix used in the output file name.
    void markAsSender(bool isSender)
    {
        mIsSender = isSender;
    }

    /// Store one measurement for @p requestId; a no-op when no output path is configured.
    /// @param size Used only to derive the bandwidth from @p duration (given in ms).
    void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size)
    {
        if (mOutputPath.empty())
        {
            return;
        }

        auto const gbps = size * 8 / (duration / 1000) / 1e9;

        std::lock_guard<std::mutex> lock(mMutex);
        mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, gbps});
    }

    /// Dump everything collected so far to "<output path>rank_<rank>_<send|recv>.csv".
    ~KvCacheMeasureHelper()
    {
        // Nothing to write when recording is disabled or no measurement was collected.
        if (mRequestKVCacheTranfserMeasure.empty() || mOutputPath.empty())
        {
            return;
        }

        TLLM_CHECK(mIsSender.has_value());
        auto const rank = mpi::MpiComm::world().getRank();
        char const* direction = mIsSender.value() ? "send" : "recv";
        std::string const csvPath = mOutputPath + "rank_" + std::to_string(rank) + "_" + direction + ".csv";

        std::ofstream csv(csvPath);
        TLLM_CHECK_WITH_INFO(csv.is_open(), "Cannot write to file " + csvPath);

        // The header width is taken from the first request's number of measures.
        auto const& firstMeasures = mRequestKVCacheTranfserMeasure.begin()->second;
        csv << "RequestID";
        for (auto remaining = firstMeasures.size(); remaining > 0; --remaining)
        {
            csv << ",Delay(ms),Duration(ms),Bandwidth(Gbps)";
        }
        csv << '\n';

        // One row per request: id followed by each measure's three values.
        for (auto const& [id, entries] : mRequestKVCacheTranfserMeasure)
        {
            csv << id;
            for (auto const& entry : entries)
            {
                csv << "," << entry.delay << "," << entry.duration << "," << entry.bandwidth;
            }
            csv << '\n';
        }

        csv.close();
    }

private:
    std::map<LlmRequest::RequestIdType, std::vector<Measure>> mRequestKVCacheTranfserMeasure;
    std::string mOutputPath;
    std::mutex mMutex;
    std::optional<bool> mIsSender;
};

} // namespace tensorrt_llm::batch_manager
47 changes: 43 additions & 4 deletions cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

#include <filesystem>

namespace tensorrt_llm::batch_manager
{

Expand All @@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
}

namespace fs = std::filesystem;

/// @brief Build the per-rank CSV path used for KV cache transfer measurements.
///
/// The directory comes from common::getEnvKVCacheTransferOutputPath() and is created
/// if needed. The file is named "rank_<rank>_<tag>.csv".
///
/// @param tag Suffix placed after the rank in the file name.
/// @return The full file path, or an empty path when no output directory is configured.
static fs::path getTransferOutputPath(char const* tag)
{
    auto const& outputDir = common::getEnvKVCacheTransferOutputPath();
    if (outputDir.empty())
    {
        return {};
    }

    auto const rank = mpi::MpiComm::world().getRank();
    fs::path const dir(outputDir);
    fs::create_directories(dir);

    std::string const fileName = "rank_" + std::to_string(rank) + "_" + tag + ".csv";
    return dir / fileName;
}

DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
: mManager{manager}
Expand All @@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
{
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
mFormatter->markAsSender(true);
}

[[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo()
Expand Down Expand Up @@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
if (it == mRequestToSession.end())
{
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
!common::getEnvKVCacheTransferOutputPath().empty());
it = mRequestToSession.emplace(requestId, std::move(session)).first;
}
it->second.setConnection(peerIdx, connection);
Expand Down Expand Up @@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
auto it = mRequestToSession.find(requestId);
TLLM_CHECK(it != mRequestToSession.end());
std::unique_lock<std::mutex> lk(mMtxForMap);
if (!common::getEnvKVCacheTransferOutputPath().empty())
{
if (!mMeasuresFile.is_open())
{
auto outputPath = getTransferOutputPath("send");
mMeasuresFile.open(outputPath);
TLLM_CHECK_WITH_INFO(
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
}
it->second.exportMeasure(mMeasuresFile, true);
}
mRequestToSession.erase(it);
}

Expand All @@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage
TLLM_CHECK(mManager);
TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex);
TLLM_CHECK(mFormatter);
mFormatter->markAsSender(false);
}

TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
Expand Down Expand Up @@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
}
auto const& resource = getReceiveCacheResource(llmRequest);
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
contextState, resource->mBufferManager, &llmRequest);
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
}

/// @brief Run the formatter's unformat step for a session and, when a transfer output
///        path is configured, append the session's measures to the "recv" CSV file.
///
/// The measures file is opened on first use; opening and writing happen while
/// holding mMeasuresFileMutex.
///
/// @param session Transfer session to process.
void DataReceiverImpl::receiveSync(TransferSession& session)
{
    mFormatter->unformat(session);

    // Measurement export is opt-in: skip it entirely when no output path is set.
    if (common::getEnvKVCacheTransferOutputPath().empty())
    {
        return;
    }

    std::lock_guard<std::mutex> guard(mMeasuresFileMutex);
    if (!mMeasuresFile.is_open())
    {
        auto const csvPath = getTransferOutputPath("recv");
        mMeasuresFile.open(csvPath);
        TLLM_CHECK_WITH_INFO(
            mMeasuresFile.is_open(), "Failed to open transfer output file: %s", csvPath.string().c_str());
    }
    session.exportMeasure(mMeasuresFile, false);
}

void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info)
Expand Down
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"

#include <fstream>

namespace tensorrt_llm::batch_manager
{
struct TransceiverTag
Expand Down Expand Up @@ -67,6 +69,7 @@ class DataSenderImpl : public DataSender, public TransceiverTag
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::mutex mMtxForMap;
runtime::BufferManager mBufferManager;
std::ofstream mMeasuresFile;
};

class DataReceiverImpl : public DataReceiver, public TransceiverTag
Expand Down Expand Up @@ -103,6 +106,8 @@ class DataReceiverImpl : public DataReceiver, public TransceiverTag
std::unique_ptr<BaseCacheFormatter> mFormatter;
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
std::mutex mProcessIoResouceMutex;
std::ofstream mMeasuresFile;
std::mutex mMeasuresFileMutex;
};

} // namespace tensorrt_llm::batch_manager
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ void MLACacheFormatter::format(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};

if (connections.size() > 1)
Expand Down Expand Up @@ -433,7 +433,7 @@ void MLACacheFormatter::unformat(TransferSession& session)
}
double cacheTransferTime
= std::max(0.0, std::chrono::duration<double, std::milli>(endTime - startTime).count());
kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size);
session.appendMeasure(delay, cacheTransferTime, size);
};

if (pickUpConnections.size() > 1)
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/common/envUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ size_t getEnvAllReduceWorkspaceSize()
return workspaceSize;
}

std::string getEnvKVCacheTransferOutputPath()
std::string const& getEnvKVCacheTransferOutputPath()
{
static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or("");
return outputPath;
Expand Down
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/common/envUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap();

bool getEnvEnableReceiveKVCacheParallel();

std::string getEnvKVCacheTransferOutputPath();
std::string const& getEnvKVCacheTransferOutputPath();

bool getEnvTryZCopyForKVCacheTransfer();

Expand Down