Commit 80b26d2

apply Lora by altering layer's weights instead of adding more calculations in forward()
1 parent 69eb2a9 commit 80b26d2
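
The old code added the LoRA contribution on every forward() call; this commit instead folds each LoRA's up/down matrices into the layer weights once, so plain forward() calls run unchanged afterwards. A minimal sketch of the idea, not the extension's actual API (merge_lora_into_weight, up, down, multiplier and alpha are illustrative names):

import torch

def merge_lora_into_weight(layer, up, down, multiplier=1.0, alpha=None):
    # Fold the low-rank product directly into the layer weight, so the
    # forward pass needs no extra matmuls afterwards.
    scale = (alpha / up.shape[1]) if alpha else 1.0
    with torch.no_grad():
        layer.weight += (up @ down) * multiplier * scale

# Previously the equivalent work happened on every call, roughly:
#   out = layer(x) + up_layer(down_layer(x)) * multiplier * scale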

File tree: 2 files changed (66 additions, 18 deletions)


extensions-builtin/Lora/lora.py

Lines changed: 56 additions & 16 deletions
@@ -131,7 +131,7 @@ def load_lora(name, filename):
         with torch.no_grad():
             module.weight.copy_(weight)
 
-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)
 
         if lora_key == "lora_up.weight":
             lora_module.up = module
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)
 
 
-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores original weights from backup and alters weights according to loras.
+    """
 
-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
+
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
+
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
 
-    return res
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)
 
 
 def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def list_available_loras():
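
The new lora_apply_weights path follows a standard weight-merging recipe: keep a CPU backup of the pristine weight, restore it whenever the selected set of Loras (or their multipliers) changes, then add each Lora's up @ down product scaled by its multiplier and alpha/rank. The _load_from_state_dict wrappers reset the cached names and the backup so a freshly loaded checkpoint is never mixed with a stale backup. Below is a self-contained sketch of the same pattern on a plain nn.Linear, using illustrative names (LoraPair, apply_loras) rather than the extension's own:

import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class LoraPair:
    up: torch.Tensor          # (out_features, rank)
    down: torch.Tensor        # (rank, in_features)
    multiplier: float = 1.0
    alpha: float = 0.0        # 0 means "no alpha", i.e. a scale of 1.0

def apply_loras(layer: nn.Linear, loras: list):
    # Back up the original weight once, on CPU, so it can always be restored.
    if getattr(layer, "weights_backup", None) is None:
        layer.weights_backup = layer.weight.detach().to("cpu", copy=True)

    with torch.no_grad():
        # Start from the pristine weight, then fold in the requested Loras.
        layer.weight.copy_(layer.weights_backup.to(layer.weight.device, dtype=layer.weight.dtype))
        for l in loras:
            scale = (l.alpha / l.up.shape[1]) if l.alpha else 1.0
            layer.weight += (l.up @ l.down).to(layer.weight.dtype) * l.multiplier * scale

# usage sketch
layer = nn.Linear(16, 16)
rank = 4
pair = LoraPair(up=torch.randn(16, rank) * 0.01, down=torch.randn(rank, 16) * 0.01)
apply_loras(layer, [pair])   # weights now include the Lora; forward() is unmodified
apply_loras(layer, [])       # restores the original weights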

extensions-builtin/Lora/scripts/lora_script.py

Lines changed: 10 additions & 2 deletions
@@ -9,7 +9,9 @@
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
 
 
 def before_ui():
@@ -20,11 +22,19 @@ def before_ui():
     if not hasattr(torch.nn, 'Linear_forward_before_lora'):
         torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
 
+    if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+        torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
     if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
         torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+    if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+        torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
     torch.nn.Linear.forward = lora.lora_Linear_forward
+    torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
     torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+    torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
 
     script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
     script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ def before_ui():
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
-
 }))
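
On the script side this is the same monkey-patching convention the extension already used for forward(): stash the original torch.nn method under a *_before_lora attribute exactly once, install the wrapper, and put the original back in unload(). A reduced illustration of that hook pattern with generic names (patched_linear_forward, unpatch), not the extension's functions:

import torch

# Save the original exactly once, so repeated script reloads don't chain wrappers.
if not hasattr(torch.nn, 'Linear_forward_original'):
    torch.nn.Linear_forward_original = torch.nn.Linear.forward

def patched_linear_forward(self, input):
    # Per-layer bookkeeping would go here (e.g. refreshing merged weights),
    # then the call is forwarded to the saved original implementation.
    return torch.nn.Linear_forward_original(self, input)

torch.nn.Linear.forward = patched_linear_forward      # install the hook

def unpatch():
    # Restore the stock behaviour, mirroring unload() above.
    torch.nn.Linear.forward = torch.nn.Linear_forward_original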
