"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple

-import numpy as np
-
import torch
import torch.nn.functional as F

from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               ReplicatedLinear,
                                               QKVParallelLinear,
+                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


-class MixtralMLP(nn.Module):
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each
+    expert across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """

    def __init__(
        self,
        num_experts: int,
+        top_k: int,
        hidden_size: int,
        intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+        params_dtype: Optional[torch.dtype] = None,
+    ):
        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w2 = ReplicatedLinear(self.ffn_dim,
-                                   self.hidden_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-        self.w3 = ReplicatedLinear(self.hidden_dim,
-                                   self.ffn_dim,
-                                   bias=False,
-                                   linear_method=linear_method)
-
-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // tp_size

-class MixtralMoE(nn.Module):
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype

-    def __init__(
-        self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}.")
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(range(
-            self.num_total_experts), self.tp_size)[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(
-                f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList([
-            MixtralMLP(self.num_total_experts,
-                       config.hidden_size,
-                       config.intermediate_size,
-                       linear_method=linear_method)
-            if idx in self.expert_indicies else None
-            for idx in range(self.num_total_experts)
-        ])
-        self.gate = ReplicatedLinear(config.hidden_size,
+        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
+                                     params_dtype=self.params_dtype,
                                     linear_method=None)

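+        # ws stacks each expert's gate (w1) and up (w3) projections along the
+        # output dimension; w2s holds the down (w2) projections. Both are
+        # allocated with only this rank's shard of the intermediate dimension.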
+        self.ws = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        2 * self.intermediate_size,
+                        self.hidden_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+        self.w2s = nn.Parameter(
+            torch.empty(self.num_total_experts,
+                        self.hidden_size,
+                        self.intermediate_size,
+                        device="cuda",
+                        dtype=self.params_dtype))
+
+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                      weight_name: str, expert_id: int):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
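+        # w1/w3 are row-parallel: copy this rank's rows into the first/second
+        # half of the stacked ws buffer. w2 is column-parallel: copy this
+        # rank's columns into w2s.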
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id,
+                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

@@ -142,22 +134,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = (selected_experts == expert_idx)
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                 keepdim=True)
-
-            current_hidden_states = expert_layer(hidden_states).mul_(
-                expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
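+        # Evaluate each token's top-k experts in a single fused kernel over
+        # this rank's shard of the expert weights; the partial results are
+        # summed across ranks by the all-reduce below.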
+        final_hidden_states = fused_moe(hidden_states,
+                                        self.ws,
+                                        self.w2s,
+                                        routing_weights,
+                                        selected_experts,
+                                        inplace=True)
+
+        final_hidden_states = tensor_model_parallel_all_reduce(
+            final_hidden_states)

-        return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+        return final_hidden_states.view(batch_size, sequence_length,
+                                        hidden_size)


class MixtralAttention(nn.Module):
@@ -257,8 +245,11 @@ def __init__(
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
-        self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -378,6 +369,14 @@ def load_weights(self,
            ("qkv_proj", "v_proj", "v"),
        ]

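+        # Map per-expert checkpoint tensors (experts.<id>.{w1,w2,w3}.weight)
+        # onto the fused ws/w2s parameters of MixtralMoE.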
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
@@ -387,6 +386,7 @@ def load_weights(self,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue
+
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
@@ -399,14 +399,22 @@ def load_weights(self,
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if ("block_sparse_moe.experts." in name
-                        and name not in params_dict):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
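+                # Expert weights go through MixtralMoE.weight_loader so that
+                # each rank keeps only its shard.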
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  weight_name,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
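
Note (not part of the commit): as a reading aid, the snippet below is a minimal, unfused sketch of what the fused_moe call above is expected to compute, written as the kind of per-token loop this change removes. It assumes full, unsharded ws/w2s tensors with the layout established by MixtralMoE.__init__ and weight_loader (expert e's w1 in rows 0:I of ws[e], w3 in rows I:2*I, and w2 in w2s[e]); the SiLU-gated combination follows the removed MixtralMLP.forward, and the helper name naive_moe is illustrative only.

# Illustrative sketch only (not the vLLM kernel): a naive reference for the
# fused MoE computation, assuming unsharded weights with
#   ws:  (num_experts, 2 * intermediate, hidden)  stacking [w1; w3] per expert
#   w2s: (num_experts, hidden, intermediate)      holding w2 per expert
import torch
import torch.nn.functional as F


def naive_moe(x, ws, w2s, routing_weights, selected_experts):
    # x: (num_tokens, hidden)
    # routing_weights, selected_experts: (num_tokens, top_k)
    inter = w2s.shape[-1]
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(selected_experts.shape[1]):
            e = int(selected_experts[t, k])
            gate_up = ws[e] @ x[t]                       # (2 * inter,)
            act = F.silu(gate_up[:inter]) * gate_up[inter:]
            out[t] += routing_weights[t, k].to(x.dtype) * (w2s[e] @ act)
    return out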