Skip to content

Commit 8f3e45f

Browse files
hlkysayakpaul
authored andcommitted
Add Flux Control to AutoPipeline (#10292)
1 parent 6d7896e commit 8f3e45f

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@
3535
)
3636
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
3737
from .flux import (
38+
FluxControlImg2ImgPipeline,
39+
FluxControlInpaintPipeline,
3840
FluxControlNetImg2ImgPipeline,
3941
FluxControlNetInpaintPipeline,
4042
FluxControlNetPipeline,
43+
FluxControlPipeline,
4144
FluxImg2ImgPipeline,
4245
FluxInpaintPipeline,
4346
FluxPipeline,
@@ -125,6 +128,7 @@
125128
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
126129
("auraflow", AuraFlowPipeline),
127130
("flux", FluxPipeline),
131+
("flux-control", FluxControlPipeline),
128132
("flux-controlnet", FluxControlNetPipeline),
129133
("lumina", LuminaText2ImgPipeline),
130134
("cogview3", CogView3PlusPipeline),
@@ -150,6 +154,7 @@
150154
("lcm", LatentConsistencyModelImg2ImgPipeline),
151155
("flux", FluxImg2ImgPipeline),
152156
("flux-controlnet", FluxControlNetImg2ImgPipeline),
157+
("flux-control", FluxControlImg2ImgPipeline),
153158
]
154159
)
155160

@@ -168,6 +173,7 @@
168173
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
169174
("flux", FluxInpaintPipeline),
170175
("flux-controlnet", FluxControlNetInpaintPipeline),
176+
("flux-control", FluxControlInpaintPipeline),
171177
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
172178
]
173179
)
@@ -401,16 +407,20 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
401407

402408
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
403409
orig_class_name = config["_class_name"]
410+
if "ControlPipeline" in orig_class_name:
411+
to_replace = "ControlPipeline"
412+
else:
413+
to_replace = "Pipeline"
404414

405415
if "controlnet" in kwargs:
406416
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
407-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline")
417+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
408418
else:
409-
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
419+
orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
410420
if "enable_pag" in kwargs:
411421
enable_pag = kwargs.pop("enable_pag")
412422
if enable_pag:
413-
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
423+
orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
414424

415425
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
416426

@@ -694,8 +704,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
694704

695705
# the `orig_class_name` can be:
696706
# `- *Pipeline` (for regular text-to-image checkpoint)
707+
# - `*ControlPipeline` (for Flux tools specific checkpoint)
697708
# `- *Img2ImgPipeline` (for refiner checkpoint)
698-
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
709+
if "Img2Img" in orig_class_name:
710+
to_replace = "Img2ImgPipeline"
711+
elif "ControlPipeline" in orig_class_name:
712+
to_replace = "ControlPipeline"
713+
else:
714+
to_replace = "Pipeline"
699715

700716
if "controlnet" in kwargs:
701717
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -707,6 +723,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
707723
if enable_pag:
708724
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
709725

726+
if to_replace == "ControlPipeline":
727+
orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
728+
710729
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
711730

712731
kwargs = {**load_config_kwargs, **kwargs}
@@ -994,8 +1013,14 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
9941013

9951014
# The `orig_class_name`` can be:
9961015
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
1016+
# - `*ControlPipeline` (for Flux tools specific checkpoint)
9971017
# - or *Pipeline (for regular text-to-image checkpoint)
998-
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
1018+
if "Inpaint" in orig_class_name:
1019+
to_replace = "InpaintPipeline"
1020+
elif "ControlPipeline" in orig_class_name:
1021+
to_replace = "ControlPipeline"
1022+
else:
1023+
to_replace = "Pipeline"
9991024

10001025
if "controlnet" in kwargs:
10011026
if isinstance(kwargs["controlnet"], ControlNetUnionModel):
@@ -1006,6 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
10061031
enable_pag = kwargs.pop("enable_pag")
10071032
if enable_pag:
10081033
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
1034+
if to_replace == "ControlPipeline":
1035+
orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
10091036
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
10101037

10111038
kwargs = {**load_config_kwargs, **kwargs}

0 commit comments

Comments
 (0)