
Commit b7bf797

Inclusion of InternVLChatModel in PP_SUPPORTED_MODELS (Pipeline Parallelism) (vllm-project#7860)
1 parent d3b718f commit b7bf797

File tree

6 files changed: +90 −35 lines


tests/distributed/test_pipeline_parallel.py

Lines changed: 22 additions & 16 deletions
@@ -18,23 +18,26 @@
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 
 
-@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
-                          "MODEL_NAME, DIST_BACKEND"),
-                         [
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-                             (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-                             (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-                         ])
+@pytest.mark.parametrize(
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
+     "MODEL_NAME, DIST_BACKEND"),
+    [
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"),
+    ],
+)
 @fork_new_process_for_each_test
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
+                    TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
+    if TRUST_REMOTE_CODE:
+        pp_args.append("--trust-remote-code")
+        tp_args.append("--trust-remote-code")
     pp_env = None
     if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
             and CHUNKED_PREFILL):
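
Note: a minimal, self-contained sketch of how the integer flags in the parametrize table end up as server CLI arguments. The helper name build_server_args is hypothetical and the flag set is illustrative of the test's pattern, not a copy of it.

def build_server_args(model_name, tp_size, pp_size, eager_mode,
                      chunked_prefill, trust_remote_code):
    # Hypothetical helper mirroring the test's flag-appending pattern.
    args = [
        model_name,
        "--tensor-parallel-size", str(tp_size),
        "--pipeline-parallel-size", str(pp_size),
    ]
    if chunked_prefill:
        args.append("--enable-chunked-prefill")
    if eager_mode:
        args.append("--enforce-eager")
    if trust_remote_code:
        args.append("--trust-remote-code")
    return args

# The new InternLM2 case: TP=2, PP=2, eager mode, chunked prefill, remote code.
print(build_server_args("internlm/internlm2_5-7b-chat", 2, 2, 1, 1, 1))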

tests/utils.py

Lines changed: 6 additions & 1 deletion
@@ -178,7 +178,12 @@ def compare_two_settings(model: str,
         env2: The second set of environment variables to pass to the API server.
     """
 
-    tokenizer = AutoTokenizer.from_pretrained(model)
+    trust_remote_code = "--trust-remote-code"
+    if trust_remote_code in arg1 or trust_remote_code in arg2:
+        tokenizer = AutoTokenizer.from_pretrained(model,
+                                                  trust_remote_code=True)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model)
 
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
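
Note: a quick sketch of the tokenizer behavior this change enables (assumes network access to the Hugging Face Hub; the model name is simply the one added to the test matrix above).

from transformers import AutoTokenizer

args = ["--pipeline-parallel-size", "2", "--trust-remote-code"]
tokenizer = AutoTokenizer.from_pretrained(
    "internlm/internlm2_5-7b-chat",
    trust_remote_code="--trust-remote-code" in args)
token_ids = tokenizer("Hello, my name is")["input_ids"]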

vllm/config.py

Lines changed: 5 additions & 3 deletions
@@ -35,18 +35,20 @@
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
 
 _PP_SUPPORTED_MODELS = [
-    "AquilaModel",
     "AquilaForCausalLM",
+    "AquilaModel",
     "DeepseekV2ForCausalLM",
+    "GPT2LMHeadModel",
+    "InternLM2ForCausalLM",
     "InternLMForCausalLM",
+    "InternVLChatModel",
     "JAISLMHeadModel",
     "LlamaForCausalLM",
     "LLaMAForCausalLM",
     "MistralForCausalLM",
-    "Phi3ForCausalLM",
-    "GPT2LMHeadModel",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
+    "Phi3ForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
     "QWenLMHeadModel",

vllm/model_executor/models/internlm2.py

Lines changed: 38 additions & 14 deletions
@@ -1,14 +1,14 @@
 # -*- coding: utf-8 -*-
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
@@ -28,6 +28,9 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
+
 
 class InternLM2MLP(nn.Module):
 
@@ -234,6 +237,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -243,11 +247,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLMDecoderLayer(config, cache_config,
+                                                quant_config),
+            prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.tok_embeddings(input_ids)
@@ -260,21 +268,31 @@ def forward(
         attn_metadata: AttentionMetadata,
         intermediate_tensors: IntermediateTensors = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.tok_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -298,6 +316,8 @@ def __init__(
         self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
 
     def forward(
         self,
@@ -308,7 +328,7 @@ def forward(
         intermediate_tensors: IntermediateTensors,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(
@@ -345,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -353,6 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
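
Note: the forward changes follow the usual pattern for pipeline-parallel decoders: the first rank computes embeddings, later ranks resume from the previous stage's hidden_states/residual, and only the last rank applies the final norm. A framework-free toy sketch of that control flow (all names are stand-ins, not vLLM code):

def staged_forward(layers, start_layer, end_layer, is_first_rank, is_last_rank,
                   embed, norm, input_ids=None, intermediate=None):
    if is_first_rank:
        hidden_states = embed(input_ids)
        residual = None
    else:
        # Non-first ranks resume from the previous stage's outputs.
        assert intermediate is not None
        hidden_states = intermediate["hidden_states"]
        residual = intermediate["residual"]
    # Only the layers owned by this pipeline stage run here.
    for layer in layers[start_layer:end_layer]:
        hidden_states, residual = layer(hidden_states, residual)
    if not is_last_rank:
        # Hand off to the next stage instead of finishing.
        return {"hidden_states": hidden_states, "residual": residual}
    return norm(hidden_states if residual is None else hidden_states + residual)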

vllm/model_executor/models/internvl.py

Lines changed: 3 additions & 1 deletion
@@ -341,6 +341,8 @@ def __init__(self,
                       nn.Linear(llm_hidden_size, llm_hidden_size))
 
         self.img_context_token_id = None
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -461,7 +463,7 @@ def forward(
             positions,
             kv_caches,
             attn_metadata,
-            None,
+            intermediate_tensors,
             inputs_embeds=inputs_embeds)
         return hidden_states
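
Note: the InternVL change is pure delegation: the multimodal wrapper re-exports the language model's make_empty_intermediate_tensors and passes intermediate_tensors through instead of hard-coding None. A minimal sketch of that pattern with stand-in class names:

class VisionLanguageWrapper:
    """Stand-in for a multimodal model wrapping a PP-aware language model."""

    def __init__(self, language_model):
        self.language_model = language_model
        # Re-export so the engine can allocate stage-to-stage buffers.
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)

    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                intermediate_tensors, inputs_embeds=None):
        return self.language_model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds=inputs_embeds)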

vllm/model_executor/models/utils.py

Lines changed: 16 additions & 0 deletions
@@ -12,6 +12,7 @@
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
 
@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
         if name.startswith(missing_layer_name):
             return True
     return False
+
+
+def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
+
+    def make_empty_intermediate_tensors(
+            batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            key: torch.zeros((batch_size, hidden_size),
+                             dtype=dtype,
+                             device=device)
+            for key in keys
+        })
+
+    return make_empty_intermediate_tensors
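
Note: a small usage sketch of the new factory, assuming a vLLM build that includes this commit; the hidden size and batch size are arbitrary, and key lookup on IntermediateTensors matches how the forward pass above consumes it.

import torch

from vllm.model_executor.models.utils import (
    make_empty_intermediate_tensors_factory)

make_empty = make_empty_intermediate_tensors_factory(
    ["hidden_states", "residual"], hidden_size=4096)
buffers = make_empty(batch_size=8, dtype=torch.float16,
                     device=torch.device("cpu"))
assert buffers["hidden_states"].shape == (8, 4096)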
