@@ -124,7 +124,6 @@ def init_toppings(self):
         self.origin_target_modules = set()
         for name, top in self.available_toppings.items():
             self.configs[name] = ToppingConfig(topping_type=top[0], path=top[1])
-
             self.origin_target_modules = set(self.origin_target_modules) | set(
                 self.configs[name].hf_config["target_modules"]
             )
@@ -134,6 +133,8 @@ def init_toppings(self):
                 self.base_model.get_module_name(module)
                 for module in self.origin_target_modules
             }
+            # remove down_proj from target modules
+            logger.info(f"Target modules: {self.target_modules}")
         else:
             logger.warning(
                 "WARNING: get_module_name() is not defined, "
@@ -166,6 +167,7 @@ def init_toppings(self):
         self.lora_id = {}
         self.deltas: List[DeltaAdapter] = []
         self.delta_id = {}
+
         for name in self.available_toppings.keys():
             t_type = self.available_toppings[name][0]
             logger.info(f"Loading {t_type} {name}")
@@ -189,10 +191,15 @@ def init_toppings(self):
             self.deltas[-1].initialize_weights()

         # misc lora configs
-        self.max_lora_dim = max(
-            [x.hf_config["r"] for x in self.configs.values() if "r" in x.hf_config]
-        )
-        self.scaling = self.loras[0].scaling
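+        # With no LoRA configs present (delta-only case), max() over an empty
+        # list would raise ValueError; fall back to zero dim and scaling.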
+        self.max_lora_dim = [
+            x.hf_config["r"] for x in self.configs.values() if "r" in x.hf_config
+        ]
+        if len(self.max_lora_dim) == 0:
+            self.max_lora_dim = 0
+            self.scaling = 0
+        else:
+            self.max_lora_dim = max(self.max_lora_dim)
+            self.scaling = self.loras[0].scaling
         # FIXME remove the restrictions
         assert all(
             x.hf_config["r"] == self.max_lora_dim
@@ -215,6 +222,9 @@ def print_available_toppings(self):
     def set_topping_module(self, module_name, module):
         topping_module = get_topping_layer(module)
         replace_submodule(self.base_model, module_name, topping_module)
+        logger.info(
+            f"Replaced {module_name} with topping module {type(topping_module)}"
+        )
         return topping_module

     def prepare_topping_batch(self, forward_batch: ForwardBatch):
@@ -288,7 +298,6 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
             dtype=torch.int64,
             device=forward_batch.input_ids.device,
         )
-        print(f"weight_indices: {weight_indices}")
         for module_name, module in self.topping_modules:
             layer_id = get_layer_id(module_name)
             if "lm_head" in module_name:
@@ -327,6 +336,42 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
                         self.scales_buffer["kv_proj"][layer_id][:len_deltas],
                     ),
                 )
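+            # Not every topping provides down_proj weights, so each buffer
+            # lookup below falls back to None when the weight name is absent.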
339+ elif "down_proj" in module_name :
340+ weight_name = self .get_weight_name (module_name , 0 )
341+ module .set_topping_info (
342+ bs ,
343+ weight_indices ,
344+ lora_buffer = (
345+ (
346+ self .A_buffer [weight_name ][layer_id ][:len_loras ]
347+ if weight_name in self .A_buffer
348+ else None
349+ ),
350+ (
351+ self .B_buffer [weight_name ][layer_id ][:len_loras ]
352+ if weight_name in self .B_buffer
353+ else None
354+ ),
355+ ),
356+ delta_buffer = (
357+ (
358+ self .qweight_buffer [weight_name ][layer_id ][:len_deltas ]
359+ if weight_name in self .qweight_buffer
360+ else None
361+ ),
362+ (
363+ self .meta_buffer [weight_name ][layer_id ][:len_deltas ]
364+ if weight_name in self .meta_buffer
365+ else None
366+ ),
367+ (
368+ self .scales_buffer [weight_name ][layer_id ][:len_deltas ]
369+ if weight_name in self .scales_buffer
370+ else None
371+ ),
372+ ),
373+ debug = False ,
374+ )
             else:
                 weight_name = self.get_weight_name(module_name, 0)
                 module.set_topping_info(
@@ -375,6 +420,7 @@ def load_topping(self, uid, buffer_id):
375420 """
376421 This function loads topping from CPU -> GPU memory
377422 """
423+
378424 if uid not in self .available_toppings :
379425 logger .error (f"Topping { uid } not registered" )
380426 return
@@ -420,6 +466,7 @@ def _load_delta(self, uid, buffer_id):

         for i in range(num_layer):
             layer_weights = self.deltas[self.delta_id[uid]].layers[i].weights
+            # load to buffer space
             for name, weights in layer_weights.items():
                 if (
                     "qkv_proj" in name
@@ -445,7 +492,7 @@ def _load_delta(self, uid, buffer_id):
                         self.scales_buffer[kv_proj_name][i][buffer_id].copy_(
                             weights[:, q_dim:]
                         )
-                else:
+                elif "meta" in name:
                     q_proj_name = "q_proj"
                     kv_proj_name = "kv_proj"
                     q_dim = self.meta_buffer[q_proj_name][i][buffer_id].shape[0]
@@ -455,23 +502,30 @@ def _load_delta(self, uid, buffer_id):
                     self.meta_buffer[kv_proj_name][i][buffer_id].copy_(
                         weights[q_dim:, :]
                     )
+                else:
+                    print(f"Unknown delta weight name: {name}")
             else:
                 if "qweight" in name:
                     weight_name = self.get_delta_weight_name(name)
                     if weight_name:
                         self.qweight_buffer[weight_name][i][buffer_id].copy_(
                             weights
                         )
+                    else:
+                        print(f"Unknown delta weight name: {name}")
+
465517 elif "scales" in name :
466518 weight_name = self .get_delta_weight_name (name )
467519 if weight_name :
468520 self .scales_buffer [weight_name ][i ][buffer_id ].copy_ (weights )
521+
469522 elif "meta" in name :
470523 weight_name = self .get_delta_weight_name (name )
471524 if weight_name :
472525 self .meta_buffer [weight_name ][i ][buffer_id ].copy_ (weights )
473526 else :
                     print("Unknown delta weight name: {name}")
+                    raise ValueError(f"Unknown delta weight name: {name}")

         for name, outside_module in self.deltas[
             self.delta_id[uid]