Commit 9ab2c02

Support sequence parallelism combined with pipeline parallelism (#18243)
Signed-off-by: cascade812 <[email protected]>
1 parent 66e63e8 commit 9ab2c02
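
This change lifts the earlier restriction under which vLLM disabled sequence parallelism (with a warning) whenever pipeline parallelism was also enabled. As a rough usage sketch only (not part of the commit; it assumes a 4-GPU host and that this version's LLM entrypoint accepts a dict for compilation_config), the two can now be combined like this:

# Hypothetical example: 2-way tensor parallelism x 2-way pipeline parallelism
# with the sequence-parallelism compilation pass requested explicitly.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    compilation_config={
        "pass_config": {"enable_sequence_parallelism": True},
    },
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)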

File tree

3 files changed: +74 -27 lines changed

tests/distributed/test_sequence_parallel.py

Lines changed: 35 additions & 2 deletions
@@ -26,6 +26,7 @@
 
 class ParallelSetup(NamedTuple):
     tp_size: int
+    pp_size: int
     sp_enabled: bool
     eager_mode: bool
     chunked_prefill: bool
@@ -60,25 +61,50 @@ def __post_init__(self):
     def detailed(
         *,
         tp_base: int = 2,
+        pp_base: int = 1,
         multi_node_only: bool = False,
         task: TaskOption = "auto",
         load_format: Optional[str] = None,
     ):
         return SPTestSettings(
             parallel_setups=[
                 ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
                               sp_enabled=True,
                               eager_mode=False,
                               chunked_prefill=False),
                 ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
                               sp_enabled=True,
                               eager_mode=False,
                               chunked_prefill=True),
                 ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
                               sp_enabled=True,
                               eager_mode=True,
                               chunked_prefill=False),
                 ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=True,
+                              chunked_prefill=True),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=2 * pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=False),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=2 * pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=True),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=2 * pp_base,
+                              sp_enabled=True,
+                              eager_mode=True,
+                              chunked_prefill=False),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=2 * pp_base,
                               sp_enabled=True,
                               eager_mode=True,
                               chunked_prefill=True)
@@ -94,13 +120,20 @@ def detailed(
     def fast(
         *,
         tp_base: int = 2,
+        pp_base: int = 1,
         task: TaskOption = "auto",
         multi_node_only: bool = False,
         load_format: Optional[str] = None,
     ):
         return SPTestSettings(
             parallel_setups=[
                 ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=False),
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=2 * pp_base,
                               sp_enabled=True,
                               eager_mode=False,
                               chunked_prefill=False),
@@ -136,6 +169,7 @@ def _compare_sp(
 ):
     (
         tp_size,
+        pp_size,
         sp_enabled,
         eager_mode,
         chunked_prefill,
@@ -167,7 +201,6 @@ def _compare_sp(
     else:
         model_info.check_available_online(on_fail="skip")
 
-    pp_size = 1
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
@@ -256,7 +289,7 @@ def _compare_sp(
 
 SP_TEXT_GENERATION_MODELS = {
     # [Decoder-only]
-    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(),
+    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
 }
 
 SP_TEST_MODELS = [

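For illustration, a self-contained sketch (it mirrors the structures in the diff above rather than importing them from vLLM) of what the updated fast() defaults expand to, and how many GPUs each setup needs, which is why the test skips when num_gpus_available < tp_size * pp_size:

from typing import NamedTuple

class ParallelSetup(NamedTuple):
    tp_size: int
    pp_size: int
    sp_enabled: bool
    eager_mode: bool
    chunked_prefill: bool

def fast(tp_base: int = 2, pp_base: int = 1) -> list[ParallelSetup]:
    # The same two combinations the patched SPTestSettings.fast() generates:
    # a pure-TP setup and a TP + PP setup.
    return [
        ParallelSetup(tp_base, pp_base, True, False, False),
        ParallelSetup(tp_base, 2 * pp_base, True, False, False),
    ]

for setup in fast():
    print(setup, "needs", setup.tp_size * setup.pp_size, "GPUs")
    # tp_size=2, pp_size=1 needs 2 GPUs; tp_size=2, pp_size=2 needs 4 GPUs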
vllm/config.py

Lines changed: 0 additions & 12 deletions
@@ -4287,18 +4287,6 @@ def __post_init__(self):
             self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()
 
-        if self.parallel_config is not None and \
-            self.parallel_config.tensor_parallel_size > 1 and \
-            self.parallel_config.pipeline_parallel_size > 1 and \
-            self.compilation_config is not None and \
-            self.compilation_config.pass_config is not None and \
-            self.compilation_config.pass_config.enable_sequence_parallelism:
-            logger.warning_once(
-                "Sequence parallelism is not supported with pipeline "
-                "parallelism. Disabling sequence parallelism.")
-            self.compilation_config.pass_config.\
-                enable_sequence_parallelism = False
-
         self._set_cudagraph_sizes()
 
         if self.cache_config is not None and \

vllm/v1/worker/gpu_model_runner.py

Lines changed: 39 additions & 13 deletions
@@ -1056,6 +1056,40 @@ def apply_grammar_bitmask(
             indices=out_indices,
         )
 
+    def sync_and_slice_intermediate_tensors(
+            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
+            sync_self: bool) -> IntermediateTensors:
+
+        assert self.intermediate_tensors is not None
+
+        tp = self.vllm_config.parallel_config.tensor_parallel_size
+        enabled_sp = self.vllm_config.compilation_config.pass_config. \
+            enable_sequence_parallelism
+        if enabled_sp:
+            # When sequence parallelism is enabled, we always pad num_tokens
+            # to be a multiple of tensor_parallel_size (tp) earlier
+            assert num_tokens % tp == 0
+        is_residual_scattered = tp > 1 and enabled_sp \
+            and num_tokens % tp == 0
+
+        # When sequence parallelism is enabled, the "residual" tensor is sharded
+        # across tensor parallel ranks, so each rank only needs its own slice.
+        if sync_self:
+            assert intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                is_scattered = "residual" and is_residual_scattered
+                copy_len = num_tokens // tp if is_scattered else \
+                    num_tokens
+                self.intermediate_tensors[k][:copy_len].copy_(
+                    v[:copy_len], non_blocking=True)
+
+        return IntermediateTensors({
+            k:
+            v[:num_tokens // tp]
+            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            for k, v in self.intermediate_tensors.items()
+        })
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1131,15 +1165,8 @@ def execute_model(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            assert intermediate_tensors is not None
-            assert self.intermediate_tensors is not None
-            for k, v in intermediate_tensors.items():
-                self.intermediate_tensors[k][:num_input_tokens].copy_(
-                    v[:num_input_tokens], non_blocking=True)
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_input_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_input_tokens, intermediate_tensors, True)
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
@@ -1658,10 +1685,9 @@ def _dummy_run(
                     batch_size=self.max_num_tokens,
                     dtype=self.model_config.dtype,
                     device=self.device))
-            intermediate_tensors = IntermediateTensors({
-                k: v[:num_tokens]
-                for k, v in self.intermediate_tensors.items()
-            })
+
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                num_tokens, None, False)
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,

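A standalone illustration of the slicing rule applied by sync_and_slice_intermediate_tensors above (plain PyTorch with made-up sizes, not code from the commit): when sequence parallelism is enabled, the "residual" tensor is scattered across tensor-parallel ranks, so each rank keeps only num_tokens // tp rows of it, while every other intermediate tensor keeps the full num_tokens rows.

import torch

num_tokens, tp, hidden = 8, 2, 16  # num_tokens is padded to a multiple of tp
enabled_sp = True                  # sequence-parallelism pass enabled
buffers = {
    "hidden_states": torch.zeros(num_tokens, hidden),
    "residual": torch.zeros(num_tokens, hidden),
}
is_residual_scattered = tp > 1 and enabled_sp and num_tokens % tp == 0

# Same comprehension shape as the new helper: slice "residual" to its
# per-rank shard, keep the full length for everything else.
sliced = {
    k: v[:num_tokens // tp]
    if k == "residual" and is_residual_scattered else v[:num_tokens]
    for k, v in buffers.items()
}
print({k: tuple(v.shape) for k, v in sliced.items()})
# {'hidden_states': (8, 16), 'residual': (4, 16)}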