 from scratchpad.nn.utils import apply_torchao_config_
 from scratchpad.scheduler.schedule_batch import global_args
 from scratchpad.model_executor.forward_info import ForwardBatch
-from triteia.python.nn.linear import sparse_low_precision_linear
-from triteia.python.ops.matmul.sbmm import sbmm_4bit_2_4_native, sbmm_4bit_2_4_multilaunch, sbmm_4bit_2_4_forloop

 class LlamaMLP(nn.Module):
     def __init__(
@@ -61,78 +59,11 @@ def __init__(
         self.act_fn = SiluAndMul()

     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
         x, _ = self.down_proj(x)
         return x

-class LLamaSBmm(nn.Module):
-    def __init__(self, num_experts, infeatures, outfeatures, sbmm_type="naive", groupsize=-1):
-        super().__init__()
-        if groupsize == -1:
-            groupsize = infeatures
-        self.infeatures = infeatures
-        self.outfeatures = outfeatures
-        self.groupsize = groupsize
-        self.qweight = nn.Parameter(torch.empty((num_experts, self.infeatures // 32, self.outfeatures * 16 // 8), dtype=torch.int32), False)
-        self.meta = nn.Parameter(torch.empty((num_experts, self.outfeatures, self.infeatures // 16), dtype=torch.int16), False)
-        self.scales = nn.Parameter(torch.empty((num_experts, self.infeatures // groupsize, self.outfeatures), dtype=torch.float16), False)
-        self.workspace = nn.Parameter(torch.zeros(num_experts, self.outfeatures // 128 * 16, dtype=torch.int32), False)
-        if sbmm_type == "naive":
-            self.sbmm_func = sbmm_4bit_2_4_native
-        elif sbmm_type == "multilaunch":
-            self.sbmm_func = sbmm_4bit_2_4_multilaunch
-        elif sbmm_type == "forloop":
-            self.sbmm_func = sbmm_4bit_2_4_forloop
-        else:
-            raise NotImplementedError
-
-    def forward(self, x, indices):
-        return self.sbmm_func(
-            qweights=self.qweight.data,
-            xs=x,
-            metas=self.meta.data,
-            ss=self.scales.data,
-            indices=indices)
-
-
-class LlamaCompressedMLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        num_experts: int,
-        sbmm_type: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.intermediate_size = intermediate_size
-        self.hidden_size = hidden_size
-        self.gate_up_proj = LLamaSBmm(
-            num_experts=num_experts,
-            infeatures=hidden_size,
-            outfeatures=intermediate_size * 2,
-            sbmm_type=sbmm_type,
-        )
-        self.down_proj = LLamaSBmm(
-            num_experts=num_experts,
-            infeatures=intermediate_size,
-            outfeatures=hidden_size,
-            sbmm_type=sbmm_type,
-        )
-
-    def forward(self, x, indices):
-        # assert not x.isnan().any()
-        gate_up = self.gate_up_proj(x, indices)
-        # assert not gate_up.isnan().any()
-        d = x.shape[-1] // 2
-        x = F.silu(x[..., :d]) * x[..., d:]
-        # assert not x.isnan().any()
-        x = self.down_proj(x, indices)
-        # assert not x.isnan().any()
-        return x

 class LlamaMoE(nn.Module):
     def __init__(
@@ -142,75 +73,58 @@ def __init__(
         hidden_act: str,
         num_experts: int,
         experts_per_token: int,
-        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.experts_per_token = experts_per_token
         self.num_experts = num_experts
-        self.base_mlp = LlamaMLP(
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            hidden_act=hidden_act,
-            quant_config=quant_config,
-            prefix=f"{prefix}.mlp.EXPERT_ID",
-        )
-
-        self.mlp = LlamaCompressedMLP(
-            num_experts=num_experts,
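+        # Replace the single compressed MLP with one dense LlamaMLP per expert, indexed by expert id.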
+        self.mlp = nn.ModuleList([
+            LlamaMLP(
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             hidden_act=hidden_act,
             quant_config=quant_config,
-            sbmm_type=sbmm_type,
-            prefix=f"{prefix}.mlp."
-        )
+            prefix=f"{prefix}.mlp.{i}"
+            ) for i in range(num_experts)
+        ])
         self.gate = nn.Linear(hidden_size, num_experts, bias=False)

     def forward(self, x):
-        base_y = self.base_mlp(x)
         original_shape = x.shape
         x = x.view(1, *x.shape) if x.dim() == 2 else x
         batch_size, sequence_length, hidden_dim = x.shape
-
         x = x.view(-1, hidden_dim)
         router_logits = self.gate(x)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.experts_per_token, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(x.dtype).T
+        routing_weights = routing_weights.to(x.dtype)
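+        # routing_weights stays in (num_tokens, experts_per_token) layout; each row sums to 1 after the renormalization above.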
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=x.dtype, device=x.device
         )
-        sort_selected_experts, argsort_selected_experts = torch.sort(selected_experts.T, dim=-1)
-        for k in range(self.experts_per_token):
-            current_selected_experts = sort_selected_experts[k]
-            current_routing_weights = routing_weights[k].view(-1, 1)
-            current_argsort_selected_experts = argsort_selected_experts[k]
-            sort_x = x[current_argsort_selected_experts]
-            current_hidden_states = self.mlp(sort_x, current_selected_experts)[current_argsort_selected_experts] * current_routing_weights
-            final_hidden_states += current_hidden_states
-
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).contiguous()
+
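+        # For each expert: gather the tokens routed to it, run its MLP, scale by the routing weight, and scatter the result back.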
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.mlp[expert_idx]
+            current_mask = expert_mask[expert_idx]
+            idx, top_x = torch.where(current_mask)
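+            # top_x holds the token indices routed to this expert; idx is the top-k slot that selected it (used to index routing_weights).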
+            current_state = x[None, top_x].reshape(-1, hidden_dim)
+            if current_state.nelement() != 0:
+                current_hidden_states = expert_layer(current_state)
+                current_hidden_states *= routing_weights[top_x, idx, None]
+                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(final_hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         final_hidden_states = final_hidden_states.view(original_shape)
-
-        final_hidden_states = final_hidden_states.contiguous()
-        base_y = base_y.contiguous()
-
         # For debugging
         # assert final_hidden_states.is_contiguous(), "final_hidden_states is not contiguous"
-        # assert base_y.is_contiguous(), "base_y is not contiguous"
-        # assert final_hidden_states.device == base_y.device, "Tensors are on different devices"
-        # assert final_hidden_states.dtype == base_y.dtype, "Tensors have different dtypes"
+        # print(final_hidden_states.device)
+        # print(final_hidden_states.shape)
+        # print(final_hidden_states.dtype)
+        # print(final_hidden_states)
         # assert not torch.isnan(final_hidden_states).any(), "NaN found in final_hidden_states"
-        # assert not torch.isnan(base_y).any(), "NaN found in base_y"
         # assert not torch.isinf(final_hidden_states).any(), "Inf found in final_hidden_states"
-        # assert not torch.isinf(base_y).any(), "Inf found in base_y"
-        # assert final_hidden_states.shape == base_y.shape, "Tensors have different shapes"
-        # torch.cuda.synchronize()
-        result = final_hidden_states + base_y
-        return result
+        return final_hidden_states

 class LlamaAttention(nn.Module):
     def __init__(
@@ -340,7 +254,6 @@ def __init__(
             quant_config=quant_config,
             num_experts=config.num_experts,
             experts_per_token=config.experts_per_token,
-            sbmm_type=config.sbmm_type,
             prefix=f"{prefix}.moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -505,13 +418,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())

-        name_transformations = [
-            ("down_proj.0", "down_proj"),
-            ("gate_up_proj.0", "gate_up_proj"),
-            ("mlp.EXPERT_ID", "base_mlp")
-        ]
         for name, loaded_weight in weights:
-            assert not loaded_weight.isnan().any()
+            # print(name)
+            # assert not loaded_weight.isnan().any()
             # continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -525,28 +434,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                print(name, name.replace(weight_name, param_name), shard_id)
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                for transformation in name_transformations:
-                    if transformation[0] in name:
-                        name = name.replace(transformation[0], transformation[1])
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name == "lm_head.0.weight":
-                    continue
-                if name == "model.embed_tokens.0.weight":
-                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                for transformation in name_transformations:
-                    if transformation[0] in name:
-                        name = name.replace(transformation[0], transformation[1])
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)