
Commit 2c11e90

repair --medvram for SD2.x too after SDXL update
1 parent 1f26815 commit 2c11e90

File tree

2 files changed: +5 −4 lines changed


modules/lowvram.py

Lines changed: 4 additions & 3 deletions
@@ -90,8 +90,12 @@ def first_stage_model_decode_wrap(z):
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
+        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
     else:
         sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
@@ -101,9 +105,6 @@ def first_stage_model_decode_wrap(z):
     if sd_model.embedder:
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

-    if hasattr(sd_model, 'cond_stage_model'):
-        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
-
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:
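
The change extends the low-VRAM offloading to SD2.x: modules stay on the CPU, a forward pre-hook (send_me_to_gpu) moves each one to the GPU just before it runs, and the parents dict maps a hooked submodule to the larger wrapper that should travel with it. Below is a minimal sketch of that pattern, assuming PyTorch and an available CUDA device; the names send_me_to_gpu and parents mirror the real module, but the rest is illustrative, not the webui's actual code.

import torch
import torch.nn as nn

cpu = torch.device("cpu")
gpu = torch.device("cuda")

parents = {}           # hooked submodule -> larger module that moves with it
module_in_gpu = None   # the one module currently resident on the GPU

def send_me_to_gpu(module, _inputs):
    # Forward pre-hook: before `module` runs, evict whatever currently
    # occupies the GPU and move this module (or its registered parent) on.
    global module_in_gpu
    module = parents.get(module, module)
    if module is module_in_gpu:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module

# Usage: keep the model on the CPU; each hooked layer is pulled to the GPU
# only for its own forward pass.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to(cpu)
for layer in model:
    layer.register_forward_pre_hook(send_me_to_gpu)
out = model(torch.randn(1, 8, device=gpu))  # activations live on the GPU throughout

Registering the hook on token_embedding (and recording its parent) is what the SD2.x branch was missing after the SDXL refactor, which is why --medvram broke for those models.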

modules/sd_hijack_open_clip.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def encode_with_transformers(self, tokens):
     def encode_embedding_init_text(self, init_text, nvpt):
         ids = tokenizer.encode(init_text)
         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
-        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

         return embedded

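
With token_embedding now hooked in modules/lowvram.py, the pre-hook moves the embedding onto the GPU before it runs, so ids created on devices.device no longer need the explicit .to(...weight.device) workaround. A self-contained illustration of why the manual move becomes redundant (assumes a CUDA device; the module size and hook are generic stand-ins, not the webui's):

import torch
import torch.nn as nn

emb = nn.Embedding(49408, 1024)          # starts on the CPU, like a --medvram model
ids = torch.tensor([[1, 2, 3]], device="cuda")

def to_input_device(module, inputs):
    # Stand-in for send_me_to_gpu: move the module to the input's device
    # before forward, so weights and ids already agree.
    module.to(inputs[0].device)

emb.register_forward_pre_hook(to_input_device)
embedded = emb(ids).squeeze(0)           # no manual ids.to(emb.weight.device) needed
print(embedded.shape)                    # torch.Size([3, 1024])

Without the pre-hook, emb(ids) would raise a device-mismatch error, which is exactly what the removed .to() call had been papering over.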
