
Commit 96c376a

a-r-r-o-w, stevhliu, and yiyixuxu authored

[core] LTX Video (#10021)

* transformer
* make style & make fix-copies
* transformer
* add transformer tests
* 80% vae
* make style
* make fix-copies
* fix
* undo cogvideox changes
* update
* update
* match vae
* add docs
* t2v pipeline working; scheduler needs to be checked
* docs
* add pipeline test
* update
* update
* make fix-copies
* Apply suggestions from code review
  Co-authored-by: Steven Liu <[email protected]>
* update
* copy t2v to i2v pipeline
* update
* apply review suggestions
* update
* make style
* remove framewise encoding/decoding
* pack/unpack latents
* image2video
* update
* make fix-copies
* update
* update
* rope scale fix
* debug layerwise code
* remove debug
* Apply suggestions from code review
  Co-authored-by: YiYi Xu <[email protected]>
* propagate precision changes to i2v pipeline
* remove downcast
* address review comments
* fix comment
* address review comments
* [Single File] LTX support for loading original weights (#10135)
* from original file mixin for ltx
* undo config mapping fn changes
* update
* add single file to pipelines
* update docs
* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
* rename classes based on ltx review
* point to original repository for inference
* make style
* resolve conflicts correctly

---------

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

1 parent 8170dc3 commit 96c376a

26 files changed: +4439 −1 lines changed

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions

```diff
@@ -274,6 +274,8 @@
       title: LatteTransformer3DModel
     - local: api/models/lumina_nextdit2d
       title: LuminaNextDiT2DModel
+    - local: api/models/ltx_video_transformer3d
+      title: LTXVideoTransformer3DModel
     - local: api/models/mochi_transformer3d
       title: MochiTransformer3DModel
     - local: api/models/pixart_transformer2d
@@ -312,6 +314,8 @@
       title: AutoencoderKLAllegro
     - local: api/models/autoencoderkl_cogvideox
       title: AutoencoderKLCogVideoX
+    - local: api/models/autoencoderkl_ltx_video
+      title: AutoencoderKLLTXVideo
     - local: api/models/autoencoderkl_mochi
       title: AutoencoderKLMochi
     - local: api/models/asymmetricautoencoderkl
@@ -408,6 +412,8 @@
       title: Latte
     - local: api/pipelines/ledits_pp
       title: LEDITS++
+    - local: api/pipelines/ltx_video
+      title: LTX
     - local: api/pipelines/lumina
       title: Lumina-T2X
     - local: api/pipelines/marigold
```
docs/source/en/api/models/autoencoderkl_ltx_video.md

Lines changed: 37 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTXVideo

The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
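For orientation, here is a hedged encode/decode round-trip sketch. It assumes the usual diffusers VAE conventions (`encode(...).latent_dist` and `decode(...).sample`) and a pixel-space video tensor of shape `(batch, channels, frames, height, width)`; the dummy shape below is an illustrative assumption, not a prescribed input size.

```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")

# Dummy video: (batch, channels, frames, height, width); the shape is an illustrative assumption
video = torch.randn(1, 3, 9, 256, 256, device="cuda")

with torch.no_grad():
    # encode() returns an AutoencoderKLOutput; sample a latent from its distribution
    latents = vae.encode(video).latent_dist.sample()
    # decode() returns a DecoderOutput; .sample holds the reconstructed video
    reconstruction = vae.decode(latents).sample
```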
## AutoencoderKLLTXVideo

[[autodoc]] AutoencoderKLLTXVideo
  - decode
  - encode
  - all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
docs/source/en/api/models/ltx_video_transformer3d.md

Lines changed: 30 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTXVideoTransformer3DModel

The Diffusion Transformer (DiT) model for 3D video data used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
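As a quick sanity check after loading, here is a minimal sketch that counts the model's parameters via the standard `torch.nn.Module` API (the `TODO/TODO` repository id is the same placeholder as above):

```python
import torch
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Count parameters using the standard nn.Module API
num_params = sum(p.numel() for p in transformer.parameters())
print(f"{num_params / 1e9:.2f}B parameters")
```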
## LTXVideoTransformer3DModel

[[autodoc]] LTXVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
docs/source/en/api/pipelines/ltx_video.md

Lines changed: 68 additions & 0 deletions

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX

[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide models for both text-to-video and image + text-to-video use cases.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
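For orientation, here is a hedged end-to-end text-to-video sketch following the standard diffusers pipeline pattern; the prompt, resolution, frame count, and step count are illustrative assumptions rather than settings prescribed by this commit.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Prompt and generation settings are illustrative assumptions
prompt = "A clear mountain stream flowing over smooth stones, golden hour light"

video = pipe(
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```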
## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)

# ... inference code ...
```

Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].

```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer")
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
```
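Since both snippets above construct an [`LTXImageToVideoPipeline`], here is a hedged end-to-end image-to-video sketch. The conditioning image URL, prompt, and generation settings are illustrative assumptions:

```python
import torch
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Conditioning image and prompt are illustrative assumptions
image = load_image("https://example.com/first_frame.png")  # hypothetical URL; replace with a real image
prompt = "An astronaut slowly turning toward the camera, cinematic lighting"

video = pipe(
    image=image,
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```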
## LTXPipeline

[[autodoc]] LTXPipeline
  - all
  - __call__

## LTXImageToVideoPipeline

[[autodoc]] LTXImageToVideoPipeline
  - all
  - __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput

scripts/convert_ltx_to_diffusers.py

Lines changed: 209 additions & 0 deletions

```python
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel


def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


TOKENIZER_MAX_LENGTH = 128

# Substring renames applied to every transformer state-dict key
TRANSFORMER_KEYS_RENAME_DICT = {
    "patchify_proj": "proj_in",
    "adaln_single": "time_embed",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}

# Substring renames applied to every VAE state-dict key
VAE_KEYS_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0",
    "up_blocks.2": "up_blocks.1.upsamplers.0",
    "up_blocks.3": "up_blocks.1",
    "up_blocks.4": "up_blocks.2.conv_in",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.conv_in",
    "up_blocks.8": "up_blocks.3.upsamplers.0",
    "up_blocks.9": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.0.conv_out",
    "down_blocks.3": "down_blocks.1",
    "down_blocks.4": "down_blocks.1.downsamplers.0",
    "down_blocks.5": "down_blocks.1.conv_out",
    "down_blocks.6": "down_blocks.2",
    "down_blocks.7": "down_blocks.2.downsamplers.0",
    "down_blocks.8": "down_blocks.3",
    "down_blocks.9": "mid_block",
    # common
    "conv_shortcut": "conv_shortcut.conv",
    "res_blocks": "resnets",
    "norm3.norm": "norm3",
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

# Keys handled by a callback instead of a plain rename (here: dropped entirely)
VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap common checkpoint containers ("model"/"module"/"state_dict")
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
):
    PREFIX_KEY = ""

    original_state_dict = get_state_dict(load_file(ckpt_path))
    transformer = LTXVideoTransformer3DModel().to(dtype=dtype)

    # First pass: strip the prefix and apply the substring renames
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY) :]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Second pass: apply special-case handlers
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer


def convert_vae(ckpt_path: str, dtype: torch.dtype):
    original_state_dict = get_state_dict(load_file(ckpt_path))
    vae = AutoencoderKLLTXVideo().to(dtype=dtype)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

    if args.transformer_ckpt_path is not None:
        transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(
                args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
            )

    if args.vae_ckpt_path is not None:
        vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

    if args.save_pipeline:
        text_encoder_id = "google/t5-v1_1-xxl"
        tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
        text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

        if args.typecast_text_encoder:
            text_encoder = text_encoder.to(dtype=dtype)

        # Apparently, the conversion does not work anymore without this :shrug:
        for param in text_encoder.parameters():
            param.data = param.data.contiguous()

        scheduler = FlowMatchEulerDiscreteScheduler(
            use_dynamic_shifting=True,
            base_shift=0.95,
            max_shift=2.05,
            base_image_seq_len=1024,
            max_image_seq_len=4096,
            shift_terminal=0.1,
        )

        pipe = LTXPipeline(
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
        )

        pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
```
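For reference, a hypothetical invocation that converts both checkpoints and assembles a full pipeline; the paths are placeholders, and the flags are those defined in `get_args()` above:

```
python scripts/convert_ltx_to_diffusers.py \
  --transformer_ckpt_path /path/to/original/transformer.safetensors \
  --vae_ckpt_path /path/to/original/vae.safetensors \
  --save_pipeline \
  --output_path ./ltx-video-diffusers \
  --dtype bf16
```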
