Skip to content

Commit a9f04ed

Browse files
Implement text encoder part of HunyuanDiT loras.
1 parent a475ec2 commit a9f04ed

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

comfy/lora.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1+
"""
2+
This file is part of ComfyUI.
3+
Copyright (C) 2024 Comfy
4+
5+
This program is free software: you can redistribute it and/or modify
6+
it under the terms of the GNU General Public License as published by
7+
the Free Software Foundation, either version 3 of the License, or
8+
(at your option) any later version.
9+
10+
This program is distributed in the hope that it will be useful,
11+
but WITHOUT ANY WARRANTY; without even the implied warranty of
12+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13+
GNU General Public License for more details.
14+
15+
You should have received a copy of the GNU General Public License
16+
along with this program. If not, see <https://www.gnu.org/licenses/>.
17+
"""
18+
119
import comfy.utils
220
import logging
321

@@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}):
218236
lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
219237
key_map[lora_key] = k
220238

221-
for k in sdk: #OneTrainer SD3 lora
222-
if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
223-
l_key = k[len("t5xxl.transformer."):-len(".weight")]
224-
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
225-
key_map[lora_key] = k
239+
for k in sdk:
240+
if k.endswith(".weight"):
241+
if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
242+
l_key = k[len("t5xxl.transformer."):-len(".weight")]
243+
lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
244+
key_map[lora_key] = k
245+
elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
246+
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
247+
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
248+
key_map[lora_key] = k
249+
226250

227251
k = "clip_g.transformer.text_projection.weight"
228252
if k in sdk:

0 commit comments

Comments
 (0)