@@ -28,7 +28,8 @@ __global__ void lamport_initialize_kernel(float* ptr, int size)
 
 void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
 {
-    lamport_initialize_kernel<<<bytes / 128, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
+    int grid_size = (bytes + 127) / 128;
+    lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
 }
 
 Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
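A note on the hunk above: the old launch used floor division, so any bytes value that is not a multiple of 128 left a tail of the buffer uninitialized, and bytes < 128 launched zero blocks outright. A minimal sketch of the grid-size arithmetic, in Python for illustration:

def old_grid(num_bytes: int) -> int:
    return num_bytes // 128          # floor division: drops the partial tile

def new_grid(num_bytes: int) -> int:
    return (num_bytes + 127) // 128  # ceiling division: covers the tail

assert old_grid(100) == 0            # zero blocks: nothing gets initialized
assert new_grid(100) == 1            # one block covers the partial tile
assert old_grid(256) == new_grid(256) == 2  # exact multiples are unaffected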
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
@@ -1989,6 +1989,10 @@ void residualRmsNorm(
 void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
 {
     sync_check_cuda_error(stream);
+    if (size == 0)
+    {
+        return;
+    }
     switch (dataType)
     {
     case nvinfer1::DataType::kFLOAT:
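The early return above guards the degenerate case: with ceiling-division grid sizing, a zero-byte buffer would otherwise reach a kernel launch with an empty grid, which CUDA rejects as an invalid configuration. A hedged Python sketch of the control flow, where launch_fill_kernel is a hypothetical stand-in for the dtype-specific launches in the switch:

def lamport_initialize(buffer, size, dtype, stream):
    # Guard first: a zero-element buffer must never reach a launch site.
    if size == 0:
        return
    launch_fill_kernel(buffer, size, dtype, stream)  # hypothetical dispatch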
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/distributed/ops.py
@@ -32,7 +32,8 @@ def get_deepseek_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
     if mapping not in deepseek_allreduce_workspaces:
         ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
             mapping,
-            CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size),
+            CustomAllReduceHelper.max_workspace_size_auto(
+                mapping.tp_size, support_deterministic=False),
         )
         deepseek_allreduce_workspaces[mapping] = (ipc_buffers, workspace)
     return deepseek_allreduce_workspaces[mapping][1]
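A hedged reading of this call-site change: the DeepSeek allreduce fusion workspace now opts out of the deterministic-mode size override, so enabling deterministic mode no longer inflates it. Illustrative only; the support_deterministic=True call below assumes other call sites keep the override:

fusion_bytes = CustomAllReduceHelper.max_workspace_size_auto(
    mapping.tp_size, support_deterministic=False)  # always auto-sized
plain_bytes = CustomAllReduceHelper.max_workspace_size_auto(
    mapping.tp_size, support_deterministic=True)   # honors deterministic mode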
4 changes: 4 additions & 0 deletions tensorrt_llm/llmapi/llm.py
@@ -517,6 +517,10 @@ def _build_model(self):
         if self.args.kv_cache_config is not None:
             executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.kv_cache_config)
+        if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+            # Disable KV cache reuse for deterministic mode
+            executor_config.kv_cache_config.enable_block_reuse = False
+            executor_config.kv_cache_config.enable_partial_reuse = False
         if self.args.peft_cache_config is not None:
             executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.peft_cache_config)
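For illustration, this gate is driven purely by the environment, so the variable has to be set before the model is built. A minimal sketch, assuming the library's public LLM entry point and a placeholder model path:

import os

os.environ["FORCE_DETERMINISTIC"] = "1"  # must be set before _build_model runs

from tensorrt_llm import LLM

llm = LLM(model="/path/to/model")  # KV cache block and partial reuse now disabled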
6 changes: 3 additions & 3 deletions tensorrt_llm/plugin/plugin.py
@@ -704,8 +704,8 @@ def set_workspace_tensor(self,
         )
 
     @staticmethod
-    def max_workspace_size_auto(tp_size: int) -> int:
-        if force_all_reduce_deterministic():
+    def max_workspace_size_auto(tp_size: int, support_deterministic) -> int:
+        if force_all_reduce_deterministic() and support_deterministic:
             workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
                                        "1000000000")
             return int(workspace_size)
@@ -746,7 +746,7 @@ def allocate_workspace(mapping: Mapping,
             lamport_buffers_0.local_ptr,
             lamport_buffers_1.local_ptr,
             lamport_buffers_2.local_ptr,
-            size * mapping.tp_size,
+            lamport_buffers_size,
         )
         buffers = [
             ipc_buffers_ping, ipc_buffers_pong, ipc_barriers_in,
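Taken together with the ops.py hunk, a minimal sketch of the updated helper; the non-deterministic branch is hidden in the diff, so it is elided here as well:

@staticmethod
def max_workspace_size_auto(tp_size: int, support_deterministic) -> int:
    # The deterministic override now applies only when the caller opts in.
    if force_all_reduce_deterministic() and support_deterministic:
        workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
                                   "1000000000")
        return int(workspace_size)
    ...  # default auto-sizing, unchanged and not shown in the diff

The second hunk is a separate fix: the lamport initializer now receives lamport_buffers_size, presumably the byte count used when the three lamport buffers were allocated, instead of the unrelated size * mapping.tp_size, so the cleared region matches the allocation.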