Skip to content

Commit 26d7b84

Browse files
DN6sayakpaul
authored andcommitted
[Single File] Add single file support for Mochi Transformer (#10268)
update
1 parent e854770 commit 26d7b84

File tree

3 files changed

+116
-1
lines changed

3 files changed

+116
-1
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
convert_ldm_vae_checkpoint,
3333
convert_ltx_transformer_checkpoint_to_diffusers,
3434
convert_ltx_vae_checkpoint_to_diffusers,
35+
convert_mochi_transformer_checkpoint_to_diffusers,
3536
convert_sd3_transformer_checkpoint_to_diffusers,
3637
convert_stable_cascade_unet_single_file_to_diffusers,
3738
create_controlnet_diffusers_config_from_ldm,
@@ -96,6 +97,10 @@
9697
"default_subfolder": "vae",
9798
},
9899
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
100+
"MochiTransformer3DModel": {
101+
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
102+
"default_subfolder": "transformer",
103+
},
99104
}
100105

101106

src/diffusers/loaders/single_file_utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
],
107107
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
108108
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
109+
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
109110
}
110111

111112
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -159,6 +160,7 @@
159160
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
160161
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
161162
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
163+
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
162164
}
163165

164166
# Use to configure model sample size when original config is provided
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
618620
else:
619621
model_type = "autoencoder-dc-f128c512"
620622

623+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
624+
model_type = "mochi-1-preview"
625+
621626
else:
622627
model_type = "v1"
623628

@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
17581763
return new_weight
17591764

17601765

1766+
def swap_proj_gate(weight):
1767+
proj, gate = weight.chunk(2, dim=0)
1768+
new_weight = torch.cat([gate, proj], dim=0)
1769+
return new_weight
1770+
1771+
17611772
def get_attn2_layers(state_dict):
17621773
attn2_layers = []
17631774
for key in state_dict.keys():
@@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict):
24142425
handler_fn_inplace(key, converted_state_dict)
24152426

24162427
return converted_state_dict
2428+
2429+
2430+
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
2431+
new_state_dict = {}
2432+
2433+
# Comfy checkpoints add this prefix
2434+
keys = list(checkpoint.keys())
2435+
for k in keys:
2436+
if "model.diffusion_model." in k:
2437+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
2438+
2439+
# Convert patch_embed
2440+
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
2441+
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
2442+
2443+
# Convert time_embed
2444+
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
2445+
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
2446+
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
2447+
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
2448+
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
2449+
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
2450+
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
2451+
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
2452+
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
2453+
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
2454+
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
2455+
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
2456+
2457+
# Convert transformer blocks
2458+
num_layers = 48
2459+
for i in range(num_layers):
2460+
block_prefix = f"transformer_blocks.{i}."
2461+
old_prefix = f"blocks.{i}."
2462+
2463+
# norm1
2464+
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
2465+
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
2466+
if i < num_layers - 1:
2467+
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
2468+
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2469+
else:
2470+
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
2471+
old_prefix + "mod_y.weight"
2472+
)
2473+
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
2474+
2475+
# Visual attention
2476+
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
2477+
q, k, v = qkv_weight.chunk(3, dim=0)
2478+
2479+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
2480+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
2481+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
2482+
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
2483+
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
2484+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
2485+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
2486+
2487+
# Context attention
2488+
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
2489+
q, k, v = qkv_weight.chunk(3, dim=0)
2490+
2491+
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
2492+
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
2493+
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
2494+
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
2495+
old_prefix + "attn.q_norm_y.weight"
2496+
)
2497+
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
2498+
old_prefix + "attn.k_norm_y.weight"
2499+
)
2500+
if i < num_layers - 1:
2501+
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
2502+
old_prefix + "attn.proj_y.weight"
2503+
)
2504+
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
2505+
2506+
# MLP
2507+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
2508+
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
2509+
)
2510+
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
2511+
if i < num_layers - 1:
2512+
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
2513+
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
2514+
)
2515+
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
2516+
2517+
# Output layers
2518+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
2519+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
2520+
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
2521+
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
2522+
2523+
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
2524+
2525+
return new_state_dict

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23+
from ...loaders.single_file_model import FromOriginalModelMixin
2324
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
2425
from ...utils.torch_utils import maybe_allow_in_graph
2526
from ..attention import FeedForward
@@ -304,7 +305,7 @@ def forward(
304305

305306

306307
@maybe_allow_in_graph
307-
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
308+
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
308309
r"""
309310
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
310311

0 commit comments

Comments
 (0)