[core] LTX Video #10021

Merged: 64 commits merged into main from ltx-integration on Dec 12, 2024 (diff shown from 58 commits).

Commits (all authored by a-r-r-o-w):
b28f89d  (Nov 25, 2024)  transformer
f082cc8  (Nov 25, 2024)  make style & make fix-copies
a255045  (Nov 26, 2024)  transformer
36c9b40  (Nov 26, 2024)  add transformer tests
c3bd2e4  (Nov 26, 2024)  80% vae
43f7907  (Nov 26, 2024)  make style
02a2b6b  (Nov 26, 2024)  make fix-copies
c901641  (Nov 26, 2024)  fix
868cd47  (Nov 27, 2024)  undo cogvideox changes
db13a83  (Nov 27, 2024)  update
11d2d91  (Nov 27, 2024)  update
d320105  (Nov 27, 2024)  match vae
755e29c  (Nov 27, 2024)  add docs
ac95930  (Nov 27, 2024)  t2v pipeline working; scheduler needs to be checked
5f185cd  (Nov 27, 2024)  docs
e580b6b  (Nov 27, 2024)  add pipeline test
13adf3f  (Nov 27, 2024)  update
c8dfa98  (Nov 27, 2024)  update
b234394  (Nov 27, 2024)  make fix-copies
e379200  (Nov 27, 2024)  Merge branch 'main' into ltx-integration
6544fcc  (Nov 27, 2024)  Apply suggestions from code review
7134e2d  (Nov 28, 2024)  update
d4a0f8e  (Nov 28, 2024)  copy t2v to i2v pipeline
a1fe164  (Nov 28, 2024)  Merge branch 'ltx-integration' of https://github.com/huggingface/diff…
f3a0b0a  (Nov 28, 2024)  Merge branch 'main' into ltx-integration
8a26886  (Nov 28, 2024)  Merge branch 'ltx-integration' of https://github.com/huggingface/diff…
06db66b  (Nov 28, 2024)  update
f8f30a5  (Nov 28, 2024)  apply review suggestions
4e89c8d  (Nov 28, 2024)  update
5391ceb  (Nov 28, 2024)  make style
c201880  (Nov 29, 2024)  remove framewise encoding/decoding
30a3bb7  (Nov 29, 2024)  pack/unpack latents
e10b7e7  (Nov 29, 2024)  Merge branch 'main' into ltx-integration
1f008fc  (Nov 29, 2024)  image2video
8e16389  (Nov 29, 2024)  update
57c41df  (Nov 29, 2024)  make fix-copies
606e6b2  (Nov 29, 2024)  update
f4b5341  (Nov 29, 2024)  update
d556b7f  (Nov 30, 2024)  rope scale fix
42ca5e6  (Nov 30, 2024)  debug layerwise code
2502399  (Nov 30, 2024)  remove debug
8c9d3d0  (Dec 1, 2024)  Apply suggestions from code review
eb962d1  (Dec 1, 2024)  propagate precision changes to i2v pipeline
6dfca2a  (Dec 1, 2024)  Merge branch 'main' into ltx-integration
5196b2a  (Dec 1, 2024)  remove downcast
d76232d  (Dec 2, 2024)  address review comments
7f4edfb  (Dec 2, 2024)  Merge branch 'main' into ltx-integration
f18cf1a  (Dec 2, 2024)  fix comment
da475ec  (Dec 2, 2024)  Merge branch 'main' into ltx-integration
4e8b2a4  (Dec 4, 2024)  address review comments
336ba36  (Dec 4, 2024)  Merge branch 'main' into ltx-integration
9ba6a06  (Dec 10, 2024)  [Single File] LTX support for loading original weights (#10135)
9f9e016  (Dec 10, 2024)  Merge branch 'main' into ltx-integration
db16983  (Dec 10, 2024)  add single file to pipelines
69400de  (Dec 10, 2024)  update docs
f5c4815  (Dec 11, 2024)  Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
2106441  (Dec 11, 2024)  Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
d997c7b  (Dec 11, 2024)  Merge branch 'main' into ltx-integration
4aa7896  (Dec 11, 2024)  rename classes based on ltx review
93d93b1  (Dec 11, 2024)  point to original repository for inference
74f186e  (Dec 11, 2024)  Merge branch 'main' into ltx-integration
c9a9ab5  (Dec 11, 2024)  make style
1e67968  (Dec 11, 2024)  resolve conflicts correctly
bee7475  (Dec 12, 2024)  Merge branch 'main' into ltx-integration
6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
@@ -272,6 +272,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/ltx_transformer3d
title: LTXTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
@@ -310,6 +312,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_ltx
title: AutoencoderKLLTX
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
@@ -404,6 +408,8 @@
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx
title: LTX
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
37 changes: 37 additions & 0 deletions docs/source/en/api/models/autoencoderkl_ltx.md
@@ -0,0 +1,37 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX

The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLLTX

vae = AutoencoderKLLTX.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
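
Once loaded, `encode` and `decode` follow the standard diffusers VAE interface. A minimal round-trip sketch, assuming a `(batch, channels, frames, height, width)` video tensor; the shapes here are illustrative:

```python
import torch

# Dummy video: 9 frames at 256x256 (the LTX VAE compresses both spatially and temporally).
video = torch.randn(1, 3, 9, 256, 256, dtype=torch.float32, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # AutoencoderKLOutput -> sampled latents
    reconstruction = vae.decode(latents).sample       # DecoderOutput -> reconstructed video
```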

## AutoencoderKLLTX

[[autodoc]] AutoencoderKLLTX
- decode
- encode
- all

## AutoencoderKLOutput

[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30 changes: 30 additions & 0 deletions docs/source/en/api/models/ltx_transformer3d.md
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTXTransformer3DModel

A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import LTXTransformer3DModel

transformer = LTXTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
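
As a quick sanity check after loading, the module can be inspected like any other `torch.nn.Module`; a sketch building on the snippet above:

```python
num_params = sum(p.numel() for p in transformer.parameters())
print(f"{num_params / 1e9:.2f}B parameters, dtype: {next(transformer.parameters()).dtype}")
```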

## LTXTransformer3DModel

[[autodoc]] LTXTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
68 changes: 68 additions & 0 deletions docs/source/en/api/pipelines/ltx.md
@@ -0,0 +1,68 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX

[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide models for both text-to-video and image + text-to-video use cases.

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
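
A minimal text-to-video sketch is shown below; the prompt, resolution, frame count, and step count are illustrative, so refer to the [original repository](https://huggingface.co/Lightricks/LTX-Video) for recommended settings:

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A woman walks along a beach at golden hour, waves rolling in behind her"
# num_frames follows the 8k+1 pattern used by the model (161 frames ~ 6.7s at 24 FPS).
video = pipe(prompt=prompt, width=768, height=512, num_frames=161, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=24)
```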

## Loading Single Files

Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].

```python
import torch
from diffusers import AutoencoderKLLTX, LTXImageToVideoPipeline, LTXTransformer3DModel

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTX.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)

# ... inference code ...
```

Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].

```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer

single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer")
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
```
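
The single-file checkpoint carries only the transformer and VAE weights, which is why the T5 text encoder and tokenizer are supplied separately above. From here, generation follows the usual image-to-video pattern; a sketch, where the input image path and generation parameters are illustrative:

```python
from diffusers.utils import export_to_video, load_image

pipe.to("cuda")
image = load_image("path/to/input.png")  # placeholder input image
video = pipe(image=image, prompt="A calm lake at sunrise", num_frames=161, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=24)
```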

## LTXPipeline

[[autodoc]] LTXPipeline
- all
- __call__

## LTXImageToVideoPipeline

[[autodoc]] LTXImageToVideoPipeline
- all
- __call__

## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
209 changes: 209 additions & 0 deletions scripts/convert_ltx_to_diffusers.py
@@ -0,0 +1,209 @@
import argparse
from typing import Any, Dict

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLLTX, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXTransformer3DModel


def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)


TOKENIZER_MAX_LENGTH = 128

TRANSFORMER_KEYS_RENAME_DICT = {
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}

VAE_KEYS_RENAME_DICT = {
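# NOTE: entries are applied as plain substring replacements, in insertion order
# (see convert_vae below); ordering matters, e.g. "up_blocks.1" -> "up_blocks.0"
# is only safe because "up_blocks.0" -> "mid_block" has already been applied.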
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}

VAE_SPECIAL_KEYS_REMAP = {
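# These handlers run after the rename pass in convert_vae, so "mean-of-means"
# and "std-of-means" have already become latents_mean / latents_std and are
# kept; the remaining per-channel statistics are dropped.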
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
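"""Unwrap common checkpoint containers ("model", "module", "state_dict") if present."""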
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)


def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""

original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXTransformer3DModel().to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True)
return transformer


def convert_vae(ckpt_path: str, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTX().to(dtype=dtype)

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

vae.load_state_dict(original_state_dict, strict=True)
return vae


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()


DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}


if __name__ == "__main__":
args = get_args()

transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]

if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None

if args.transformer_ckpt_path is not None:
transformer: LTXTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
)

if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTX = convert_vae(args.vae_ckpt_path, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)

if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)

# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()

scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)

pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)

pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
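
For reference, a pipeline saved with `--save_pipeline` can be loaded back as usual; a sketch, where the output path is a placeholder and `variant` should match the `--dtype` used during conversion:

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("/path/to/output", torch_dtype=torch.bfloat16, variant="bf16")
```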