Skip to content

Commit 0cb402a

Browse files
committed
enable custom vae for trt
1 parent 2f89960 commit 0cb402a

File tree

5 files changed

+15
-12
lines changed

5 files changed

+15
-12
lines changed

onnxruntime/python/tools/transformers/models/stable_diffusion/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:23.10-p
4646

4747
Optionally, you can update TensorRT from 8.6.1 to latest pre-release.
4848
```
49+
python3 -m pip install --upgrade pip
4950
python3 -m pip install --pre --upgrade --extra-index-url https://pypi.nvidia.com tensorrt
5051
```
5152

@@ -60,7 +61,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve
6061
--cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \
6162
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
6263
--allow_running_as_root
63-
python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl
64+
python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall
6465
```
6566

6667
If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity.

onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,15 @@ def __init__(
9090
use_vae=False,
9191
min_image_size=256,
9292
max_image_size=1024,
93+
use_fp16_vae=True,
9394
):
9495
self.version = version
9596
self._is_inpaint = is_inpaint
9697
self._is_refiner = is_refiner
9798
self._use_vae = use_vae
9899
self._min_image_size = min_image_size
99100
self._max_image_size = max_image_size
101+
self._use_fp16_vae = use_fp16_vae
100102
if is_refiner:
101103
assert self.is_xl()
102104

@@ -127,6 +129,13 @@ def stages(self) -> List[str]:
127129
def vae_scaling_factor(self) -> float:
128130
return 0.13025 if self.is_xl() else 0.18215
129131

132+
def vae_torch_fallback(self) -> bool:
133+
return self.is_xl() and not self._use_fp16_vae
134+
135+
def custom_fp16_vae(self) -> Optional[str]:
136+
# For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs
137+
return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None
138+
130139
@staticmethod
131140
def supported_versions(is_xl: bool):
132141
return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"]

onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,8 @@ def __init__(
6060
self.torch_device = torch.device(device, torch.cuda.current_device())
6161
self.stages = pipeline_info.stages()
6262

63-
# TODO: use custom fp16 for ORT_TRT, and no need to fallback to torch.
64-
self.vae_torch_fallback = self.pipeline_info.is_xl() and engine_type != EngineType.ORT_CUDA
65-
66-
# For SD XL, use an VAE that modified to run in fp16 precision without generating NaNs.
67-
self.custom_fp16_vae = (
68-
"madebyollin/sdxl-vae-fp16-fix"
69-
if self.pipeline_info.is_xl() and self.engine_type == EngineType.ORT_CUDA
70-
else None
71-
)
63+
self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback()
64+
self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()
7265

7366
self.models = {}
7467
self.engines = {}

onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ def __init__(
104104

105105
self.stages = pipeline_info.stages()
106106

107-
self.vae_torch_fallback = self.pipeline_info.is_xl()
108-
109107
self.use_cuda_graph = use_cuda_graph
110108

111109
self.tokenizer = None

onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ sympy
1313
optimum==1.13.1
1414
safetensors
1515
invisible_watermark
16+
# newer version of opencv-python migth encounter module 'cv2.dnn' has no attribute 'DictValue' error
17+
opencv-python==4.8.0.74

0 commit comments

Comments
 (0)