|
| 1 | +""" |
| 2 | + This file is part of ComfyUI. |
| 3 | + Copyright (C) 2024 Comfy |
| 4 | +
|
| 5 | + This program is free software: you can redistribute it and/or modify |
| 6 | + it under the terms of the GNU General Public License as published by |
| 7 | + the Free Software Foundation, either version 3 of the License, or |
| 8 | + (at your option) any later version. |
| 9 | +
|
| 10 | + This program is distributed in the hope that it will be useful, |
| 11 | + but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 12 | + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 13 | + GNU General Public License for more details. |
| 14 | +
|
| 15 | + You should have received a copy of the GNU General Public License |
| 16 | + along with this program. If not, see <https://www.gnu.org/licenses/>. |
| 17 | +""" |
| 18 | + |
1 | 19 | import comfy.utils |
2 | 20 | import logging |
3 | 21 |
|
@@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}): |
218 | 236 | lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config |
219 | 237 | key_map[lora_key] = k |
220 | 238 |
|
221 | | - for k in sdk: #OneTrainer SD3 lora |
222 | | - if k.startswith("t5xxl.transformer.") and k.endswith(".weight"): |
223 | | - l_key = k[len("t5xxl.transformer."):-len(".weight")] |
224 | | - lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) |
225 | | - key_map[lora_key] = k |
| 239 | + for k in sdk: |
| 240 | + if k.endswith(".weight"): |
| 241 | + if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora |
| 242 | + l_key = k[len("t5xxl.transformer."):-len(".weight")] |
| 243 | + lora_key = "lora_te3_{}".format(l_key.replace(".", "_")) |
| 244 | + key_map[lora_key] = k |
| 245 | + elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora |
| 246 | + l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")] |
| 247 | + lora_key = "lora_te1_{}".format(l_key.replace(".", "_")) |
| 248 | + key_map[lora_key] = k |
| 249 | + |
226 | 250 |
|
227 | 251 | k = "clip_g.transformer.text_projection.weight" |
228 | 252 | if k in sdk: |
|
0 commit comments