 from scratchpad.scheduler.schedule_batch import global_args
 from scratchpad.model_executor.forward_info import ForwardBatch
 from triteia.python.nn.linear import sparse_low_precision_linear
+from triteia.python.ops.matmul.sbmm import sbmm_4bit_2_4_native, sbmm_4bit_2_4_multilaunch, sbmm_4bit_2_4_forloop
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -65,37 +66,72 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x
 
+class LLamaSBmm(nn.Module):
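+    """Batched per-expert linear layer backed by triteia's 4-bit 2:4-sparse sbmm kernels.
+
+    One set of packed weights is stored per expert:
+      qweight -- int32-packed 4-bit quantized weights
+      meta    -- metadata for the 2:4 structured-sparsity pattern
+      scales  -- per-group dequantization scales (group size defaults to infeatures)
+    `sbmm_type` selects which of the imported triteia sbmm variants is called.
+    """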
+    def __init__(self, num_experts, infeatures, outfeatures, sbmm_type="naive", groupsize=-1):
+        super().__init__()
+        if groupsize == -1:
+            groupsize = infeatures
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.groupsize = groupsize
+        self.qweight = nn.Parameter(torch.empty((num_experts, self.infeatures // 32, self.outfeatures * 16 // 8), dtype=torch.int32), requires_grad=False)
+        self.meta = nn.Parameter(torch.empty((num_experts, self.outfeatures, self.infeatures // 16), dtype=torch.int16), requires_grad=False)
+        self.scales = nn.Parameter(torch.empty((num_experts, self.infeatures // groupsize, self.outfeatures), dtype=torch.float16), requires_grad=False)
+        self.workspace = nn.Parameter(torch.zeros(num_experts, self.outfeatures // 128 * 16, dtype=torch.int32), requires_grad=False)
+        if sbmm_type == "naive":
+            self.sbmm_func = sbmm_4bit_2_4_native
+        elif sbmm_type == "multilaunch":
+            self.sbmm_func = sbmm_4bit_2_4_multilaunch
+        elif sbmm_type == "forloop":
+            self.sbmm_func = sbmm_4bit_2_4_forloop
+        else:
+            raise NotImplementedError(f"unknown sbmm_type: {sbmm_type}")
+
+    def forward(self, x, indices):
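+        # `indices` carries one expert id per row of `x`; the sbmm op applies that
+        # expert's packed weights to the corresponding row.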
+        return self.sbmm_func(
+            qweights=self.qweight.data,
+            xs=x,
+            metas=self.meta.data,
+            ss=self.scales.data,
+            indices=indices,
+        )
+
 
 class LlamaCompressedMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        num_experts: int,
+        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.intermediate_size = intermediate_size
         self.hidden_size = hidden_size
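+        # A single LLamaSBmm holds every expert's gate/up and down weights; the
+        # active expert is chosen per row at forward time via `indices`.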
-        self.gate_up_proj = sparse_low_precision_linear(
-            hidden_size,
-            intermediate_size * 2,
+        self.gate_up_proj = LLamaSBmm(
+            num_experts=num_experts,
+            infeatures=hidden_size,
+            outfeatures=intermediate_size * 2,
+            sbmm_type=sbmm_type,
         )
-        self.down_proj = sparse_low_precision_linear(
-            intermediate_size,
-            hidden_size,
+        self.down_proj = LLamaSBmm(
+            num_experts=num_experts,
+            infeatures=intermediate_size,
+            outfeatures=hidden_size,
+            sbmm_type=sbmm_type,
         )
 
-    def forward(self, x):
-        assert not x.isnan().any()
-        gate_up = self.gate_up_proj(x)
-        assert not gate_up.isnan().any()
-        d = x.shape[-1] // 2
-        x = F.silu(x[..., :d]) * x[..., d:]
-        assert not x.isnan().any()
-        x = self.down_proj(x)
-        assert not x.isnan().any()
+    def forward(self, x, indices):
+        gate_up = self.gate_up_proj(x, indices)
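+        # SwiGLU: gate_up stacks [gate, up]; gate with SiLU, then multiply elementwise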
+        d = gate_up.shape[-1] // 2
+        x = F.silu(gate_up[..., :d]) * gate_up[..., d:]
+        x = self.down_proj(x, indices)
         return x
 
 class LlamaMoE(nn.Module):
@@ -106,6 +142,7 @@ def __init__(
         hidden_act: str,
         num_experts: int,
         experts_per_token: int,
+        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -119,49 +156,42 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.mlp.EXPERT_ID",
         )
-        self.mlp = nn.ModuleList([
-            LlamaCompressedMLP(
+
+        self.mlp = LlamaCompressedMLP(
+            num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             hidden_act=hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp.{i}"
-            ) for i in range(num_experts)
-        ])
+            sbmm_type=sbmm_type,
+            prefix=f"{prefix}.mlp."
+        )
         self.gate = nn.Linear(hidden_size, num_experts, bias=False)
 
     def forward(self, x):
-
         base_y = self.base_mlp(x)
         original_shape = x.shape
         x = x.view(1, *x.shape) if x.dim() == 2 else x
         batch_size, sequence_length, hidden_dim = x.shape
+
         x = x.view(-1, hidden_dim)
         router_logits = self.gate(x)
-
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.experts_per_token, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(x.dtype)
-        assert not routing_weights.isnan().any(), "routing weights have nan"
+        routing_weights = routing_weights.to(x.dtype).T
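+        # routing_weights is now (experts_per_token, tokens): row k holds each token's
+        # weight for its k-th selected expert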
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=x.dtype, device=x.device
         )
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).contiguous()
-
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.mlp[expert_idx]
-            current_mask = expert_mask
-            current_mask = current_mask[expert_idx]
-            idx, top_x = torch.where(current_mask)
-            current_state = x[None, top_x].reshape(-1, hidden_dim)
-            assert not torch.isnan(current_state).any(), "current input state has nan"
-            current_hidden_states = expert_layer(current_state)
-            assert not torch.isnan(current_hidden_states).any(), "current hidden state has nan"
-            current_hidden_states *= routing_weights[top_x, idx, None]
-            if current_hidden_states.nelement() != 0:
-                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(x.dtype))
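+        # Sorted dispatch: for each top-k slot, sort tokens by their selected expert so
+        # that each expert's rows are contiguous, run the sbmm MLP once over the whole
+        # batch, then scatter the outputs back to the original token order.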
+        sort_selected_experts, argsort_selected_experts = torch.sort(selected_experts.T, dim=-1)
+        for k in range(self.experts_per_token):
+            current_selected_experts = sort_selected_experts[k]
+            current_routing_weights = routing_weights[k].view(-1, 1)
+            current_argsort_selected_experts = argsort_selected_experts[k]
+            # inverse_sort[t] is the position of token t in the sorted batch, so
+            # indexing with it restores the original token order
+            inverse_sort = torch.argsort(current_argsort_selected_experts)
+            sort_x = x[current_argsort_selected_experts]
+            current_hidden_states = self.mlp(sort_x, current_selected_experts)[inverse_sort] * current_routing_weights
+            final_hidden_states += current_hidden_states
+
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         final_hidden_states = final_hidden_states.view(original_shape)
 
@@ -310,6 +340,7 @@ def __init__(
             quant_config=quant_config,
             num_experts=config.num_experts,
             experts_per_token=config.experts_per_token,
+            sbmm_type=config.sbmm_type,
             prefix=f"{prefix}.moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -480,7 +511,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
480511 ("mlp.EXPERT_ID" , "base_mlp" )
481512 ]
482513 for name , loaded_weight in weights :
483- # print(name)
484514 assert not loaded_weight .isnan ().any ()
485515 # continue
486516 if "rotary_emb.inv_freq" in name or "projector" in name :