7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/distributed/communicator.py
@@ -8,7 +8,7 @@
import torch.distributed as dist

from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast,
mpi_comm, mpi_isend, mpi_recv)
mpi_comm, mpi_isend, mpi_recv, mpi_send)
from tensorrt_llm.mapping import Mapping


@@ -113,6 +113,10 @@ def isend(self, buf: np.ndarray, dest, tag=0):
# non-blocking send numpy buffer
return mpi_isend(buf, dest, tag)

def send(self, buf: np.ndarray, dest, tag=0):
# blocking send numpy buffer
mpi_send(buf, dest, tag)

def recv(self, buf: np.ndarray, src, tag=0):
# in-place recv numpy buffer
return mpi_recv(buf, src, tag)
@@ -200,6 +204,7 @@ class PPComm:
# PP communication using torch.distributed with nccl backend
def __init__(self, global_mapping: Mapping):
self.mapping = global_mapping
self.send_event = torch.cuda.Event()
if not dist.is_initialized():
master_ip = os.getenv("MASTER_ADDR", "localhost")
master_port = os.getenv("MASTER_PORT", "6000")
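The communicator now exposes both flavors of point-to-point send: isend returns an MPI request that must be waited on before the numpy buffer can be reused, while the new blocking send only returns once the buffer is safe to touch again. A minimal usage sketch, assuming dist is an instance of the wrapper patched above and peer is a valid destination rank:

import numpy as np

buf = np.zeros(8, dtype=np.int64)

# Non-blocking: the returned handle must complete before buf is reused.
handle = dist.isend(buf, dest=peer, tag=0)
# ... overlap other work here ...
handle.Wait()

# Blocking: when send() returns, buf may be reused immediately and there is
# no handle to track.
dist.send(buf, dest=peer, tag=0)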
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
@@ -290,7 +290,6 @@ def __pp_init__(self):
layer for layer in self.layers[:config.num_hidden_layers]
if not layer.is_missing()
]
print(f"{self._local_layers=}, {self.pp_layer_list=}")

# add create_pipeline_interface method
pp_interface_keys = ["hidden_states", "residual"]
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pipeline_interface.py
@@ -71,7 +71,11 @@ def recv(self):

def send(self):
"""Send tensors to next rank."""
# pp_comm.send returns as soon as the NCCL send kernel is enqueued. Synchronizing on the event
# waits until the previous send has finished, which keeps an earlier PP rank from running
# multiple microbatches ahead of a later rank.
self._pp_comm.send_event.synchronize()
if self.hidden_states is not None:
self._pp_comm.send(self.hidden_states, tag=self.tag)
if self.residual is not None:
self._pp_comm.send(self.residual, tag=self.tag)
self._pp_comm.send_event.record()
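The comment above captures the core of this change: NCCL send calls return once the kernel is enqueued, so without a fence the sending rank can keep enqueuing microbatches long before the receiver consumes them. Recording a CUDA event after the sends and synchronizing on it at the top of the next send() call bounds that lead. A standalone sketch of the record/synchronize pattern, with enqueue_nccl_send standing in for the real PPComm send call (a placeholder, not an existing API):

import torch

send_event = torch.cuda.Event()

def throttled_send(tensors, enqueue_nccl_send):
    # Block on the host until the previous microbatch's send kernels have
    # actually executed, so this rank cannot run far ahead of its successor.
    send_event.synchronize()
    for t in tensors:
        enqueue_nccl_send(t)  # returns as soon as the kernel is enqueued
    # Mark the point in the stream where this microbatch's sends complete.
    send_event.record()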
12 changes: 5 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -6,8 +6,7 @@
import tensorrt_llm
import tensorrt_llm.bindings as tllm
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._utils import (mpi_allgather, mpi_broadcast,
str_dtype_to_binding, torch_dtype_to_str)
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, load_torch_hf_lora
@@ -162,11 +161,11 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
req = create_dummy_context_requests(max_num_tokens, origin_seq_len,
vocab_size)
req_ids = py_executor.enqueue_requests(req)
req_ids = mpi_broadcast(req_ids, root=0)
req_ids = py_executor.dist.broadcast(req_ids, root=0)
py_executor.start_worker()
py_executor.await_responses(req_ids)
# sync all ranks after processing dummy requests
mpi_allgather(0)
# Sync all ranks after processing the dummy requests; an MPI barrier hangs here, so allgather is used instead.
py_executor.dist.allgather(0)

torch_peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

@@ -204,8 +203,7 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
"kv_cache_manager").shutdown()

py_executor.shutdown()
# sync all ranks after creating new pyExecutor
mpi_allgather(0)
py_executor.dist.allgather(0)

return kv_cache_max_tokens

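The switch from mpi_barrier/mpi_allgather to the dist wrapper keeps the same idea: an allgather of a throwaway value only returns after every rank has reached the call, so it acts as a barrier without using MPI_Barrier (which reportedly hangs here). A minimal mpi4py illustration of why that works, separate from the TRT-LLM wrapper:

from mpi4py import MPI

def allgather_barrier(comm=MPI.COMM_WORLD):
    # Every rank contributes a dummy token; the call cannot return until all
    # ranks have entered it, so it doubles as a synchronization fence.
    return comm.allgather(0)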
16 changes: 10 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -730,7 +730,7 @@ def _maybe_get_cuda_graph(
spec_metadata = None

pipeline_interface = None
if self.mapping.pp_rank > 0:
if not self.mapping.is_first_pp_rank():
pipeline_interface = self.model.create_pipeline_interface(
batch_size)
self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
@@ -1189,10 +1189,9 @@ def _prepare_tp_inputs(

if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
if not self.mapping.is_first_pp_rank():
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface

num_generation_tokens = len(generation_requests) + len(
@@ -1323,7 +1322,6 @@ def _prepare_tp_inputs_no_cache(
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface

return inputs, None
@@ -1604,7 +1602,9 @@ def forward(self,
inputs.update(extra_model_inputs)
self.last_spec_metadata = spec_metadata

if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
if not self.mapping.is_last_pp_rank():
pp_interface = self._forward_step_intermediate(inputs)
pp_interface.send()
return self._post_forward_intermediate(inputs, pp_interface,
@@ -1637,7 +1637,9 @@ def forward(self,
self.iter_counter += 1

if maybe_graph is None:
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
if not self.mapping.is_last_pp_rank():
pp_interface = self._forward_step_intermediate(inputs)
pp_interface.send()
outputs = self._post_forward_intermediate(
@@ -1657,6 +1659,8 @@
self._cuda_graph_mem_pool)
self._cuda_graph_mem_pool = pool

if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
outputs = maybe_graph.run(inputs)
if not self.mapping.is_last_pp_rank():
pp_interface = PipelineInterface(*outputs)
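Moving the recv out of _prepare_tp_inputs and into forward() gives the eager path, the pre-graph path, and the CUDA-graph replay path the same shape: non-first ranks pull activations from the previous rank right before compute, and non-last ranks push them right after. A schematic of that per-rank control flow; run_local_layers and compute_logits are placeholders for the engine's internals, not real methods:

def pp_forward_step(mapping, inputs, pipeline_interface):
    # Schematic only: placeholders stand in for the real forward internals.
    if not mapping.is_first_pp_rank():
        # Receive hidden_states/residual from the previous pipeline rank just
        # before they are needed, not during input preparation.
        pipeline_interface.recv()

    if not mapping.is_last_pp_rank():
        pp_interface = run_local_layers(inputs)   # intermediate activations
        pp_interface.send()                       # hand off to the next rank
        return pp_interface

    return compute_logits(run_local_layers(inputs))  # last rank produces logits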
12 changes: 2 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -231,8 +231,6 @@ def __init__(self,
self.micro_batches: List[BatchStatePP
| None] = [None] * self.num_micro_batches
self.send_handles = [None] * self.num_micro_batches
# one handle each for metadata and serialized new_reqs buffer
self.send_new_reqs_handle = [None] * 2

self.inflight_req_ids = ReqIdsSet()
self.canceled_req_ids = ReqIdsSet()
@@ -1139,10 +1137,7 @@ def _broadcast_new_requests(self, new_requests):
self.dist.recv(metadata_arr, self.dist.prev_pp_rank, tag)

if not self.dist.is_last_pp_rank:
if self.send_new_reqs_handle[0] is not None:
self.send_new_reqs_handle[0].Wait()
self.send_new_reqs_handle[0] = self.dist.isend(
metadata_arr, self.dist.next_pp_rank, tag)
self.dist.send(metadata_arr, self.dist.next_pp_rank, tag)

# 2. send serialized buffer when new requests is not empty
num_new_requests = metadata_arr[0]
@@ -1153,10 +1148,7 @@
self.dist.recv(buf, self.dist.prev_pp_rank, tag)

if not self.dist.is_last_pp_rank:
if self.send_new_reqs_handle[1] is not None:
self.send_new_reqs_handle[1].Wait()
self.send_new_reqs_handle[1] = self.dist.isend(
buf, self.dist.next_pp_rank, tag)
self.dist.send(buf, self.dist.next_pp_rank, tag)

if not self.dist.is_first_pp_rank:
new_requests = dill.loads(buf.tobytes()) # nosec B301
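With blocking sends, the per-buffer isend handles and their Wait() calls disappear, and _broadcast_new_requests becomes a straight relay down the pipeline chain: receive the small metadata array from the previous rank, forward it, then do the same for the serialized payload when there are new requests. A sketch of that pattern; the two-element metadata layout and the helper name are assumptions for illustration, not the executor's exact code:

import dill
import numpy as np

def relay_new_requests(dist, new_requests, tag=0):
    if dist.is_first_pp_rank:
        payload = np.frombuffer(dill.dumps(new_requests), dtype=np.uint8)
        metadata = np.array([len(new_requests), payload.size], dtype=np.int64)
    else:
        metadata = np.zeros(2, dtype=np.int64)
        dist.recv(metadata, dist.prev_pp_rank, tag)
    if not dist.is_last_pp_rank:
        # Blocking send: no handle to stash and Wait() on later.
        dist.send(metadata, dist.next_pp_rank, tag)

    if metadata[0] > 0:  # any new requests to forward?
        if not dist.is_first_pp_rank:
            payload = np.empty(int(metadata[1]), dtype=np.uint8)
            dist.recv(payload, dist.prev_pp_rank, tag)
        if not dist.is_last_pp_rank:
            dist.send(payload, dist.next_pp_rank, tag)
        if not dist.is_first_pp_rank:
            new_requests = dill.loads(payload.tobytes())  # nosec B301
    return new_requests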
8 changes: 8 additions & 0 deletions tensorrt_llm/_utils.py
@@ -531,6 +531,14 @@ def mpi_isend(buf, dest, tag=0):
return None


def mpi_send(buf, dest, tag=0):
# blocking send of buf-like objects (e.g. numpy array)
# no-op (and returns None) when ENABLE_MULTI_DEVICE is disabled
if ENABLE_MULTI_DEVICE:
mpi_comm().Send(buf, dest, tag=tag)
return None


def mpi_recv(buf, source, tag):
# recv in buf-like object (e.g. numpy array)
if ENABLE_MULTI_DEVICE:
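mpi_send wraps mpi4py's buffer-based Comm.Send, which blocks until the send buffer can safely be reused and always returns None, in contrast to Isend, which returns a request the caller must eventually wait on. A tiny standalone two-rank illustration of the blocking pair (plain mpi4py, not TRT-LLM code):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
if comm.Get_rank() == 0 and comm.Get_size() > 1:
    buf = np.arange(4, dtype=np.int64)
    comm.Send(buf, dest=1, tag=7)    # blocking; buf is reusable on return
elif comm.Get_rank() == 1:
    out = np.empty(4, dtype=np.int64)
    comm.Recv(out, source=0, tag=7)  # blocks until the message arrives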
8 changes: 4 additions & 4 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -74,8 +74,8 @@ def test_bfloat16(self, attn_backend, torch_compile):

@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2)],
ids=["tp4", "tp2pp2"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2), (1, 4)],
ids=["tp4", "tp2pp2", "pp4"])
def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
torch_compile):
if torch_compile and pp_size > 1:
@@ -135,8 +135,8 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile):
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@parametrize_with_ids("fp8kv", [False, True])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2)],
ids=["tp4", "tp2pp2"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2), (1, 4)],
ids=["tp4", "tp2pp2", "pp4"])
def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
torch_compile):
if torch_compile:
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -28,6 +28,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=TRTLLM-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv-attn_backend=TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-attention_dp]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-cuda_graph]
@@ -187,6 +188,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv-attn_backend=FLASHINFER]
@@ -195,6 +197,7 @@
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv-attn_backend=FLASHINFER]
- condition:
ranges:
system_gpu_count:
4 changes: 0 additions & 4 deletions tests/integration/test_lists/waives.txt
@@ -361,10 +361,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpu
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-cuda_graph] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-overlap_scheduler] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-cuda_graph] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-cuda_graph] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5181511)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int8] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int4] SKIP (https://nvbugs/5176851)