11import json
2- from collections import defaultdict
32from dataclasses import dataclass , field
43from typing import Any
54
65import torch
76
87from invokeai .backend .patches .layers .base_layer_patch import BaseLayerPatch
98from invokeai .backend .patches .layers .utils import any_lora_layer_from_state_dict
9+ from invokeai .backend .patches .lora_conversions .flux_diffusers_lora_conversion_utils import _group_by_layer
1010from invokeai .backend .patches .lora_conversions .flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
1111from invokeai .backend .patches .model_patch_raw import ModelPatchRaw
1212from invokeai .backend .util import InvokeAILogger
@@ -25,11 +25,11 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me
2525
@dataclass
class GroupedStateDict:
    """State-dict entries bucketed by the submodel they belong to."""

    # Currently only the transformer; might also grow CLIP and T5 submodels.
    transformer: dict[str, Any] = field(default_factory=dict)
3131
32- def _group_state_by_submodel (state_dict : dict [str , torch . Tensor ]) -> GroupedStateDict :
32+ def _group_state_by_submodel (state_dict : dict [str , Any ]) -> GroupedStateDict :
3333 logger = InvokeAILogger .get_logger ()
3434 grouped = GroupedStateDict ()
3535 for key , value in state_dict .items ():
@@ -42,11 +42,22 @@ def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStat
4242 return grouped
4343
4444
45+ def _rename_peft_lora_keys (state_dict : dict [str , torch .Tensor ]) -> dict [str , torch .Tensor ]:
46+ """Renames keys from the PEFT LoRA format to the InvokeAI format."""
47+ renamed_state_dict = {}
48+ for key , value in state_dict .items ():
49+ renamed_key = key .replace (".lora_A." , ".lora_down." ).replace (".lora_B." , ".lora_up." )
50+ renamed_state_dict [renamed_key ] = value
51+ return renamed_state_dict
52+
53+
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Builds a ModelPatchRaw from an AI Toolkit (PEFT-style) FLUX LoRA state dict."""
    # Normalize PEFT key names, then bucket tensors first by layer, then by submodel.
    renamed = _rename_peft_lora_keys(state_dict)
    grouped = _group_state_by_submodel(_group_by_layer(renamed))

    # Each transformer layer's sub-dict becomes one patch, keyed with the FLUX transformer prefix.
    layers: dict[str, BaseLayerPatch] = {
        FLUX_LORA_TRANSFORMER_PREFIX + layer_key: any_lora_layer_from_state_dict(layer_sd)
        for layer_key, layer_sd in grouped.transformer.items()
    }
    return ModelPatchRaw(layers=layers)
0 commit comments