Commit 908fb4e

Merge pull request #14390 from wangqyqq/sdxl-inpaint

Support for SDXL-Inpaint Model

2 parents c9c105c + bfe418a

4 files changed: +130 −1 lines changed

configs/sd_xl_inpaint.yaml

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
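The only structural change from the stock sd_xl_base.yaml is in_channels: 9: the UNet takes the 4-channel noisy latent concatenated with a 1-channel inpainting mask and the 4-channel VAE latent of the masked image (4 + 1 + 4 = 9). Everything else (the 2816-dim ADM conditioning, the dual CLIP text encoders, the autoencoder) matches the base SDXL model. As a hedged sketch of how such an sgm config can be instantiated outside webui, assuming Stability's generative-models repo and its sgm.util.instantiate_from_config helper are on the path:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Each "target:" is a dotted import path; instantiate_from_config imports it
# and calls it with the matching "params:" block, recursively building the model.
config = OmegaConf.load("configs/sd_xl_inpaint.yaml")
model = instantiate_from_config(config.model)  # -> sgm DiffusionEngine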

modules/processing.py

Lines changed: 21 additions & 0 deletions
@@ -113,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

     else:
+        sd = sd_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None:
+            if diffusion_model_input.shape[1] == 9:
+                # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+                image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+                image_conditioning = images_tensor_to_samples(image_conditioning,
+                                                              approximation_indexes.get(opts.sd_vae_encode_method))
+
+                # Add the fake full 1s mask to the first dimension.
+                image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+                image_conditioning = image_conditioning.to(x.dtype)
+
+                return image_conditioning
+
         # Dummy zero conditioning if we're not using inpainting or unclip models.
         # Still takes up a bit of memory, but no encoder call.
         # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
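The pad call is the subtle line here: torch.nn.functional.pad reads the pad tuple from the last dimension backwards, so (0, 0, 0, 0, 1, 0) leaves width and height untouched and prepends one channel filled with 1.0, i.e. an all-ones "everything is masked" mask stacked on the encoded grey image. A minimal standalone sketch of just that trick (plain PyTorch, not webui code):

import torch
import torch.nn.functional as F

# pad = (W_left, W_right, H_top, H_bottom, C_front, C_back):
# only the channel dim grows, gaining one leading channel of 1.0s.
latent = torch.zeros(2, 4, 64, 64)                   # [batch, 4 latent channels, h, w]
cond = F.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
print(cond.shape)           # torch.Size([2, 5, 64, 64])
print(cond[:, 0].unique())  # tensor([1.]) -- the prepended mask channel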
@@ -371,6 +386,12 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)

+        sd = self.sampler.model_wrap.inner_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None:
+            if diffusion_model_input.shape[1] == 9:
+                return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
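With a 9-channel UNet detected, img2img reuses the existing inpainting_image_conditioning path, which builds the same kind of 5-channel tensor: a 1-channel mask stacked on the 4-channel VAE latent of the masked source image. A rough shape-only sketch with placeholder tensors (the real method also resizes, rounds, and VAE-encodes the mask and image):

import torch

mask = torch.ones(1, 1, 64, 64)                  # 1.0 marks the region to repaint
masked_image_latent = torch.zeros(1, 4, 64, 64)  # stand-in for the encoded masked image
c_concat = torch.cat([mask, masked_image_latent], dim=1)
print(c_concat.shape)  # torch.Size([1, 5, 64, 64]) -- 5 channels of conditioning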

modules/sd_models_config.py

Lines changed: 5 additions & 1 deletion
@@ -15,6 +15,7 @@
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
-        return config_sdxl
+        if diffusion_model_input.shape[1] == 9:
+            return config_sdxl_inpainting
+        else:
+            return config_sdxl
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
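The shape[1] == 9 test works because Conv2d weights are laid out as [out_channels, in_channels, kH, kW]: the first input block's weight records exactly how many channels the checkpoint's UNet expects (diffusion_model_input is fetched earlier in guess_model_config_from_state_dict, outside this hunk). A toy illustration with a fabricated state dict; the key name and shapes are illustrative only:

import torch

# 320 output channels, 9 input channels (4 latent + 1 mask + 4 masked-image
# latent), 3x3 kernel -- the signature of an SDXL inpainting UNet.
fake_sd = {'diffusion_model.input_blocks.0.0.weight': torch.empty(320, 9, 3, 3)}
w = fake_sd.get('diffusion_model.input_blocks.0.0.weight', None)
print(w is not None and w.shape[1] == 9)  # True -> pick config_sdxl_inpainting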

modules/sd_models_xl.py

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:


 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    sd = self.model.state_dict()
+    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            x = torch.cat([x] + cond['c_concat'], dim=1)
+
     return self.model(x, t, cond)
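This is where the channel arithmetic closes: the 4-channel noisy latent x concatenated with the 5-channel c_concat conditioning gives the 9 input channels that the config's UNetModel (in_channels: 9) was built for. A shape-only sketch with dummy tensors:

import torch

x = torch.randn(1, 4, 64, 64)                     # noisy image latent
cond = {'c_concat': [torch.zeros(1, 5, 64, 64)]}  # mask + masked-image latent
x = torch.cat([x] + cond['c_concat'], dim=1)
print(x.shape)  # torch.Size([1, 9, 64, 64]) -- matches in_channels: 9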
