35
35
)
36
36
from .deepfloyd_if import IFImg2ImgPipeline , IFInpaintingPipeline , IFPipeline
37
37
from .flux import (
38
+ FluxControlImg2ImgPipeline ,
39
+ FluxControlInpaintPipeline ,
38
40
FluxControlNetImg2ImgPipeline ,
39
41
FluxControlNetInpaintPipeline ,
40
42
FluxControlNetPipeline ,
43
+ FluxControlPipeline ,
41
44
FluxImg2ImgPipeline ,
42
45
FluxInpaintPipeline ,
43
46
FluxPipeline ,
125
128
("pixart-sigma-pag" , PixArtSigmaPAGPipeline ),
126
129
("auraflow" , AuraFlowPipeline ),
127
130
("flux" , FluxPipeline ),
131
+ ("flux-control" , FluxControlPipeline ),
128
132
("flux-controlnet" , FluxControlNetPipeline ),
129
133
("lumina" , LuminaText2ImgPipeline ),
130
134
("cogview3" , CogView3PlusPipeline ),
150
154
("lcm" , LatentConsistencyModelImg2ImgPipeline ),
151
155
("flux" , FluxImg2ImgPipeline ),
152
156
("flux-controlnet" , FluxControlNetImg2ImgPipeline ),
157
+ ("flux-control" , FluxControlImg2ImgPipeline ),
153
158
]
154
159
)
155
160
168
173
("stable-diffusion-xl-pag" , StableDiffusionXLPAGInpaintPipeline ),
169
174
("flux" , FluxInpaintPipeline ),
170
175
("flux-controlnet" , FluxControlNetInpaintPipeline ),
176
+ ("flux-control" , FluxControlInpaintPipeline ),
171
177
("stable-diffusion-pag" , StableDiffusionPAGInpaintPipeline ),
172
178
]
173
179
)
@@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
401
407
402
408
config = cls .load_config (pretrained_model_or_path , ** load_config_kwargs )
403
409
orig_class_name = config ["_class_name" ]
410
+ if "ControlPipeline" in orig_class_name :
411
+ to_replace = "ControlPipeline"
412
+ else :
413
+ to_replace = "Pipeline"
404
414
405
415
if "controlnet" in kwargs :
406
416
if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
407
- orig_class_name = config ["_class_name" ].replace ("Pipeline" , "ControlNetUnionPipeline" )
417
+ orig_class_name = config ["_class_name" ].replace (to_replace , "ControlNetUnionPipeline" )
408
418
else :
409
- orig_class_name = config ["_class_name" ].replace ("Pipeline" , "ControlNetPipeline" )
419
+ orig_class_name = config ["_class_name" ].replace (to_replace , "ControlNetPipeline" )
410
420
if "enable_pag" in kwargs :
411
421
enable_pag = kwargs .pop ("enable_pag" )
412
422
if enable_pag :
413
- orig_class_name = orig_class_name .replace ("Pipeline" , "PAGPipeline" )
423
+ orig_class_name = orig_class_name .replace (to_replace , "PAGPipeline" )
414
424
415
425
text_2_image_cls = _get_task_class (AUTO_TEXT2IMAGE_PIPELINES_MAPPING , orig_class_name )
416
426
@@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
694
704
695
705
# the `orig_class_name` can be:
696
706
# `- *Pipeline` (for regular text-to-image checkpoint)
707
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
697
708
# `- *Img2ImgPipeline` (for refiner checkpoint)
698
- to_replace = "Img2ImgPipeline" if "Img2Img" in config ["_class_name" ] else "Pipeline"
709
+ if "Img2Img" in orig_class_name :
710
+ to_replace = "Img2ImgPipeline"
711
+ elif "ControlPipeline" in orig_class_name :
712
+ to_replace = "ControlPipeline"
713
+ else :
714
+ to_replace = "Pipeline"
699
715
700
716
if "controlnet" in kwargs :
701
717
if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
@@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
707
723
if enable_pag :
708
724
orig_class_name = orig_class_name .replace (to_replace , "PAG" + to_replace )
709
725
726
+ if to_replace == "ControlPipeline" :
727
+ orig_class_name = orig_class_name .replace (to_replace , "ControlImg2ImgPipeline" )
728
+
710
729
image_2_image_cls = _get_task_class (AUTO_IMAGE2IMAGE_PIPELINES_MAPPING , orig_class_name )
711
730
712
731
kwargs = {** load_config_kwargs , ** kwargs }
@@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
994
1013
995
1014
# The `orig_class_name`` can be:
996
1015
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
1016
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
997
1017
# - or *Pipeline (for regular text-to-image checkpoint)
998
- to_replace = "InpaintPipeline" if "Inpaint" in config ["_class_name" ] else "Pipeline"
1018
+ if "Inpaint" in orig_class_name :
1019
+ to_replace = "InpaintPipeline"
1020
+ elif "ControlPipeline" in orig_class_name :
1021
+ to_replace = "ControlPipeline"
1022
+ else :
1023
+ to_replace = "Pipeline"
999
1024
1000
1025
if "controlnet" in kwargs :
1001
1026
if isinstance (kwargs ["controlnet" ], ControlNetUnionModel ):
@@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
1006
1031
enable_pag = kwargs .pop ("enable_pag" )
1007
1032
if enable_pag :
1008
1033
orig_class_name = orig_class_name .replace (to_replace , "PAG" + to_replace )
1034
+ if to_replace == "ControlPipeline" :
1035
+ orig_class_name = orig_class_name .replace (to_replace , "ControlInpaintPipeline" )
1009
1036
inpainting_cls = _get_task_class (AUTO_INPAINT_PIPELINES_MAPPING , orig_class_name )
1010
1037
1011
1038
kwargs = {** load_config_kwargs , ** kwargs }
0 commit comments