Skip to content

Commit b7246e1

Browse files
authored
gradio depth,normal,hough ok (#5873)
1 parent 4a83c56 commit b7246e1

File tree

6 files changed

+438
-1
lines changed

6 files changed

+438
-1
lines changed

ppdiffusers/examples/controlnet/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,35 @@ python gradio_hed2image.py
3939
python gradio_pose2image.py
4040
```
4141
![image](https://user-images.githubusercontent.com/20476674/222131475-4dc8582a-d2a2-447a-9724-85461de04c26.png)
42+
4243
## Semantic Segmentation to Image
4344
采用ADE20K分割协议的图片作为控制条件。
4445
```
4546
python gradio_seg2image_segmenter.py
4647
```
4748
![image](https://user-images.githubusercontent.com/20476674/222131908-b0c52512-ef42-4e4b-8fde-62c12c600ff2.png)
49+
50+
## Depth to Image
51+
采用Depth深度检测图片作为控制条件。
52+
```
53+
python gradio_depth2image.py
54+
```
55+
![image](https://user-images.githubusercontent.com/31800336/236171819-29085f22-c99c-4f63-b0a0-7cce6ac98ebc.jpg)
56+
57+
## Normal to Image
58+
采用Normal检测图片作为控制条件。
59+
```
60+
python gradio_normal2image.py
61+
```
62+
![image](https://user-images.githubusercontent.com/31800336/236171840-f31a4f1c-9997-41c0-83ca-4f87ca4cc870.jpg)
63+
64+
## Hough Line to Image
65+
采用HoughLine检测图片作为控制条件。
66+
```
67+
python gradio_hough2image.py
68+
```
69+
![image](https://user-images.githubusercontent.com/31800336/236171830-f9254b66-9fbd-46d3-a3bc-e905c87d0ec3.jpg)
70+
4871
# ControlNet模型训练
4972

5073
## Fill50K 训练例子

ppdiffusers/examples/controlnet/annotator/midas_paddle/api_inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def __init__(self, model_dir, model_name="dpt_hybrid", batchsize=8, device="GPU"
6363
use_static=False,
6464
use_calib_mode=False,
6565
)
66+
min_input_shape = {"image": [1, 3, 224, 224]}
67+
max_input_shape = {"image": [1, 3, 1280, 1280]}
68+
opt_input_shape = {"image": [1, 3, 384, 384]}
69+
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, opt_input_shape)
6670

6771
# disable print log when predict
6872
config.disable_glog_info()

ppdiffusers/examples/controlnet/annotator/mlsd/models/mbv2_mlsd_large.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import paddle
16-
import utils
16+
from annotator.mlsd import utils
1717

1818

1919
class BlockTypeA(paddle.nn.Layer):
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
# Copyright 2023 The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
18+
import cv2
19+
import gradio as gr
20+
import paddle
21+
from annotator.midas_paddle import MidasDetector_Infer as MidasDetector
22+
from annotator.util import HWC3, resize_image
23+
24+
from paddlenlp.trainer import set_seed as seed_everything
25+
from ppdiffusers import ControlNetModel, StableDiffusionControlNetPipeline
26+
27+
# Annotator instance that process() calls to turn the uploaded image into a control map.
apply_midas = MidasDetector()

# The ControlNet weights are loaded first and then handed to the Stable Diffusion
# pipeline; safety_checker=None is passed so no safety checker is attached.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
)
33+
34+
35+
def process(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    image_resolution,
    detect_resolution,
    ddim_steps,
    guess_mode,
    strength,
    scale,
    seed,
    eta,
):
    """Annotate ``input_image`` with ``apply_midas`` and sample images from ``pipe``.

    Args:
        input_image: Uploaded image (the UI supplies it as a numpy array).
        prompt: User prompt; ``a_prompt`` is appended to it after ``", "``.
        a_prompt: Extra text appended to ``prompt``.
        n_prompt: Forwarded to the pipeline as ``negative_prompt``.
        num_samples: Number of pipeline calls, i.e. images to generate.
        image_resolution: Passed to ``resize_image`` to derive the output size.
        detect_resolution: Passed to ``resize_image`` before running the annotator.
        ddim_steps: Forwarded as ``num_inference_steps``.
        guess_mode: If truthy, use geometrically decaying control scales
            instead of a constant ``strength``.
        strength: Base value of the 13 control scales.
        scale: Forwarded as ``guidance_scale``.
        seed: Seed for ``seed_everything``; ``-1`` draws one from [0, 65535].
        eta: Forwarded to the pipeline as ``eta``.

    Returns:
        A list whose first item is the resized control map and whose remaining
        items are the generated images, one per sample.
    """
    with paddle.no_grad():
        source = HWC3(input_image)
        # apply_midas returns a pair; only its first element is used as the control map.
        annotated, _ = apply_midas(resize_image(source, detect_resolution))
        annotated = HWC3(annotated)

        # The output size comes from the input resized to image_resolution;
        # the control map is then resized (nearest neighbour) to match it.
        height, width, _ = resize_image(source, image_resolution).shape
        detected_map = cv2.resize(annotated, (width, height), interpolation=cv2.INTER_NEAREST)

        # Scale to [0, 1] assuming 8-bit values, add a leading batch axis and
        # move the last (channel) axis to position 1.
        control = (paddle.to_tensor(detected_map.copy(), dtype=paddle.float32) / 255.0).unsqueeze(0)
        control = control.transpose([0, 3, 1, 2])

        if guess_mode:
            # Geometric ramp: the first entry is strength * 0.825**12 (roughly
            # 0.1 * strength) and the last entry is strength itself.
            control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)]
        else:
            control_scales = [strength] * 13

        if seed == -1:
            seed = random.randint(0, 65535)
        # Seeded once; the successive pipeline calls below then draw different noise.
        seed_everything(seed)

        results = [
            pipe(
                prompt + ", " + a_prompt,
                negative_prompt=n_prompt,
                image=control,
                num_inference_steps=ddim_steps,
                height=height,
                width=width,
                eta=eta,
                controlnet_conditioning_scale=control_scales,
                guidance_scale=scale,
            ).images[0]
            for _ in range(num_samples)
        ]

    return [detected_map] + results
83+
84+
85+
# Gradio front end for the depth demo. Components are created in the order they
# should appear; the names bound here are the inputs wired into process() below.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Depth Maps")
    with gr.Row():
        # Left column: image upload, prompt, run button and the advanced settings.
        with gr.Column():
            input_image = gr.Image(source="upload", type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(
                    label="Image Resolution", minimum=256, maximum=768, value=512, step=64
                )
                strength = gr.Slider(
                    label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01
                )
                guess_mode = gr.Checkbox(label="Guess Mode", value=False)
                detect_resolution = gr.Slider(
                    label="Depth Resolution", minimum=128, maximum=1024, value=384, step=1
                )
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
                )
                # process() treats a seed of -1 as "pick a random seed".
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value="best quality, extremely detailed")
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
        # Right column: gallery that receives the control map and the generated images.
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(
                grid=2, height="auto"
            )

    # The order of this list must match the positional parameters of process().
    ips = [
        input_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        detect_resolution,
        ddim_steps,
        guess_mode,
        strength,
        scale,
        seed,
        eta,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


# NOTE(review): server_name="0.0.0.0" should make the app listen on all network
# interfaces rather than only localhost - confirm this exposure is intended.
block.launch(server_name="0.0.0.0")
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
# Copyright 2023 The HuggingFace Team. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
18+
import cv2
19+
import gradio as gr
20+
import paddle
21+
from annotator.mlsd import MLSDdetector
22+
from annotator.util import HWC3, resize_image
23+
24+
from paddlenlp.trainer import set_seed as seed_everything
25+
from ppdiffusers import ControlNetModel, StableDiffusionControlNetPipeline
26+
27+
# Annotator instance that process() calls to turn the uploaded image into a control map.
apply_mlsd = MLSDdetector()

# The ControlNet weights are loaded first and then handed to the Stable Diffusion
# pipeline; safety_checker=None is passed so no safety checker is attached.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    safety_checker=None,
)
33+
34+
35+
def process(
    input_image,
    prompt,
    a_prompt,
    n_prompt,
    num_samples,
    image_resolution,
    detect_resolution,
    ddim_steps,
    guess_mode,
    strength,
    scale,
    seed,
    eta,
    value_threshold,
    distance_threshold,
):
    """Annotate ``input_image`` with ``apply_mlsd`` and sample images from ``pipe``.

    Args:
        input_image: Uploaded image (the UI supplies it as a numpy array).
        prompt: User prompt; ``a_prompt`` is appended to it after ``", "``.
        a_prompt: Extra text appended to ``prompt``.
        n_prompt: Forwarded to the pipeline as ``negative_prompt``.
        num_samples: Number of pipeline calls, i.e. images to generate.
        image_resolution: Passed to ``resize_image`` to derive the output size.
        detect_resolution: Passed to ``resize_image`` before running the annotator.
        ddim_steps: Forwarded as ``num_inference_steps``.
        guess_mode: If truthy, use geometrically decaying control scales
            instead of a constant ``strength``.
        strength: Base value of the 13 control scales.
        scale: Forwarded as ``guidance_scale``.
        seed: Seed for ``seed_everything``; ``-1`` draws one from [0, 65535].
        eta: Forwarded to the pipeline as ``eta``.
        value_threshold: Second positional argument handed to ``apply_mlsd``.
        distance_threshold: Third positional argument handed to ``apply_mlsd``.

    Returns:
        A list whose first item is the resized control map and whose remaining
        items are the generated images, one per sample.
    """
    with paddle.no_grad():
        source = HWC3(input_image)
        annotated = apply_mlsd(
            resize_image(source, detect_resolution),
            value_threshold,
            distance_threshold,
        )
        annotated = HWC3(annotated)

        # The output size comes from the input resized to image_resolution;
        # the control map is then resized (nearest neighbour) to match it.
        height, width, _ = resize_image(source, image_resolution).shape
        detected_map = cv2.resize(annotated, (width, height), interpolation=cv2.INTER_NEAREST)

        # Scale to [0, 1] assuming 8-bit values, add a leading batch axis and
        # move the last (channel) axis to position 1.
        control = (paddle.to_tensor(detected_map.copy(), dtype=paddle.float32) / 255.0).unsqueeze(0)
        control = control.transpose([0, 3, 1, 2])

        if guess_mode:
            # Geometric ramp: the first entry is strength * 0.825**12 (roughly
            # 0.1 * strength) and the last entry is strength itself.
            control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)]
        else:
            control_scales = [strength] * 13

        if seed == -1:
            seed = random.randint(0, 65535)
        # Seeded once; the successive pipeline calls below then draw different noise.
        seed_everything(seed)

        results = [
            pipe(
                prompt + ", " + a_prompt,
                negative_prompt=n_prompt,
                image=control,
                num_inference_steps=ddim_steps,
                height=height,
                width=width,
                eta=eta,
                controlnet_conditioning_scale=control_scales,
                guidance_scale=scale,
            ).images[0]
            for _ in range(num_samples)
        ]

    return [detected_map] + results
85+
86+
87+
# Gradio front end for the Hough-line (MLSD) demo. Components are created in the
# order they should appear; the names bound here are wired into process() below.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Hough Line Maps")
    with gr.Row():
        # Left column: image upload, prompt, run button and the advanced settings.
        with gr.Column():
            input_image = gr.Image(source="upload", type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(
                    label="Image Resolution", minimum=256, maximum=768, value=512, step=64
                )
                strength = gr.Slider(
                    label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01
                )
                guess_mode = gr.Checkbox(label="Guess Mode", value=False)
                detect_resolution = gr.Slider(
                    label="Hough Line Resolution", minimum=128, maximum=1024, value=512, step=1
                )
                # The two thresholds below are passed straight through to the MLSD annotator.
                value_threshold = gr.Slider(
                    label="Hough value threshold (MLSD)", minimum=0.01, maximum=2.0, value=0.1, step=0.01
                )
                distance_threshold = gr.Slider(
                    label="Hough distance threshold (MLSD)", minimum=0.01, maximum=20.0, value=0.1, step=0.01
                )
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
                )
                # process() treats a seed of -1 as "pick a random seed".
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value="best quality, extremely detailed")
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
        # Right column: gallery that receives the control map and the generated images.
        with gr.Column():
            result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(
                grid=2, height="auto"
            )

    # The order of this list must match the positional parameters of process().
    ips = [
        input_image,
        prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        detect_resolution,
        ddim_steps,
        guess_mode,
        strength,
        scale,
        seed,
        eta,
        value_threshold,
        distance_threshold,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])


# NOTE(review): server_name="0.0.0.0" should make the app listen on all network
# interfaces rather than only localhost - confirm this exposure is intended.
block.launch(server_name="0.0.0.0")

0 commit comments

Comments
 (0)