@@ -131,7 +131,7 @@ def load_lora(name, filename):
         with torch.no_grad():
             module.weight.copy_(weight)
 
-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)
 
         if lora_key == "lora_up.weight":
             lora_module.up = module
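The change above keeps freshly loaded lora up/down modules on `devices.cpu` rather than moving them onto the active device: their weights are only read when they are folded into a layer's weight, and they are cast to that layer's device and dtype at that point (see the second hunk). A minimal sketch of the pattern, with made-up shapes and a stand-in multiplier rather than code from this commit:

```python
import torch

# Sketch only: LoRA factors stay on the CPU between uses and are cast to the
# target layer's device/dtype right before being merged into its weight.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = torch.nn.Linear(8, 8).to(device)

up = torch.randn(8, 4)    # stand-in for lora_up.weight, kept on CPU (rank 4 chosen arbitrarily)
down = torch.randn(4, 8)  # stand-in for lora_down.weight, kept on CPU

with torch.no_grad():
    up_c = up.to(layer.weight.device, dtype=layer.weight.dtype)
    down_c = down.to(layer.weight.device, dtype=layer.weight.dtype)
    layer.weight += (up_c @ down_c) * 0.8  # 0.8 stands in for lora.multiplier
```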
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)
 
 
-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores original weights from backup and alters weights according to loras.
+    """
 
-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
+
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
+
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
 
-    return res
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)
 
 
 def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def list_available_loras():
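Overall, the second hunk trades the old per-call formulation `res + up(down(input)) * multiplier * (alpha / rank)` for a one-time weight edit: `up @ down` (squeezed to a plain matmul for 1x1 Conv2d kernels) is added to the layer's weight, and the untouched weight is kept as a CPU backup so a different lora selection can later be applied by restoring and re-merging. For a Linear layer the two formulations give the same output; a small self-contained check (not part of the commit, all names and shapes invented for illustration):

```python
import torch

torch.manual_seed(0)
multiplier, alpha, rank = 0.7, 4.0, 8

base = torch.nn.Linear(16, 16, bias=False)    # plays the role of the patched layer
down = torch.nn.Linear(16, rank, bias=False)  # lora_down
up = torch.nn.Linear(rank, 16, bias=False)    # lora_up
x = torch.randn(2, 16)

# Same scaling as module.alpha / module.up.weight.shape[1] in the diff.
scale = alpha / up.weight.shape[1]

with torch.no_grad():
    # Old behaviour: extra computation added to every forward pass.
    old = base(x) + up(down(x)) * multiplier * scale

    # New behaviour: fold up @ down into the weight once, then run a plain forward.
    merged = torch.nn.Linear(16, 16, bias=False)
    merged.weight.copy_(base.weight + (up.weight @ down.weight) * multiplier * scale)
    new = merged(x)

print(torch.allclose(old, new, atol=1e-6))  # True: both formulations agree
```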
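The new `lora_Linear_load_state_dict` / `lora_Conv2d_load_state_dict` hooks reset the cached backup and name tuple whenever a layer's weights are reloaded (for example on a checkpoint switch), so stale lora deltas cannot survive a weight reload. The `*_before_lora` attributes these functions call back into are installed outside this file; the wiring presumably looks roughly like the sketch below (the patching location and the exact hook being wrapped are assumptions, not shown in this commit):

```python
import torch

import lora  # the module edited by this diff; import path is an assumption

# Stash the original implementations once, under the *_before_lora names that
# lora.py calls back into, then swap in the lora-aware replacements.
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict

torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
```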