106 | 106 |     ],
107 | 107 |     "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
108 | 108 |     "autoencoder-dc-sana": "encoder.project_in.conv.bias",
| 109 | +     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
109 | 110 | }
110 | 111 | 
111 | 112 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
159 | 160 |     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
160 | 161 |     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
161 | 162 |     "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
| 163 | +     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
162 | 164 | }
163 | 165 | 
164 | 166 | # Use to configure model sample size when original config is provided
@@ -618,6 +620,9 @@ def infer_diffusers_model_type(checkpoint):
618 | 620 |         else:
619 | 621 |             model_type = "autoencoder-dc-f128c512"
620 | 622 | 
| 623 | +     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
| 624 | +         model_type = "mochi-1-preview"
| 625 | + 
621 | 626 |     else:
622 | 627 |         model_type = "v1"
623 | 628 | 
@@ -1758,6 +1763,12 @@ def swap_scale_shift(weight, dim):
1758 | 1763 |     return new_weight
1759 | 1764 | 
1760 | 1765 | 
| 1766 | + def swap_proj_gate(weight):
| 1767 | +     proj, gate = weight.chunk(2, dim=0)
| 1768 | +     new_weight = torch.cat([gate, proj], dim=0)
| 1769 | +     return new_weight
| 1770 | + 
| 1771 | + 
1761 | 1772 | def get_attn2_layers(state_dict):
1762 | 1773 |     attn2_layers = []
1763 | 1774 |     for key in state_dict.keys():
@@ -2414,3 +2425,101 @@ def remap_proj_conv_(key: str, state_dict):
2414 | 2425 |             handler_fn_inplace(key, converted_state_dict)
2415 | 2426 | 
2416 | 2427 |     return converted_state_dict
| 2428 | + 
| 2429 | + 
| 2430 | + def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
| 2431 | +     new_state_dict = {}
| 2432 | + 
| 2433 | +     # Comfy checkpoints add this prefix
| 2434 | +     keys = list(checkpoint.keys())
| 2435 | +     for k in keys:
| 2436 | +         if "model.diffusion_model." in k:
| 2437 | +             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
| 2438 | + 
| 2439 | +     # Convert patch_embed
| 2440 | +     new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
| 2441 | +     new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
| 2442 | + 
| 2443 | +     # Convert time_embed
| 2444 | +     new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
| 2445 | +     new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
| 2446 | +     new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
| 2447 | +     new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
| 2448 | +     new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
| 2449 | +     new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
| 2450 | +     new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
| 2451 | +     new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
| 2452 | +     new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
| 2453 | +     new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
| 2454 | +     new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
| 2455 | +     new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
| 2456 | + 
| 2457 | +     # Convert transformer blocks
| 2458 | +     num_layers = 48
| 2459 | +     for i in range(num_layers):
| 2460 | +         block_prefix = f"transformer_blocks.{i}."
| 2461 | +         old_prefix = f"blocks.{i}."
| 2462 | + 
| 2463 | +         # norm1
| 2464 | +         new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
| 2465 | +         new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
| 2466 | +         if i < num_layers - 1:
| 2467 | +             new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
| 2468 | +             new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
| 2469 | +         else:
| 2470 | +             new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
| 2471 | +                 old_prefix + "mod_y.weight"
| 2472 | +             )
| 2473 | +             new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
| 2474 | + 
| 2475 | +         # Visual attention
| 2476 | +         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
| 2477 | +         q, k, v = qkv_weight.chunk(3, dim=0)
| 2478 | + 
| 2479 | +         new_state_dict[block_prefix + "attn1.to_q.weight"] = q
| 2480 | +         new_state_dict[block_prefix + "attn1.to_k.weight"] = k
| 2481 | +         new_state_dict[block_prefix + "attn1.to_v.weight"] = v
| 2482 | +         new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
| 2483 | +         new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
| 2484 | +         new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
| 2485 | +         new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
| 2486 | + 
| 2487 | +         # Context attention
| 2488 | +         qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
| 2489 | +         q, k, v = qkv_weight.chunk(3, dim=0)
| 2490 | + 
| 2491 | +         new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
| 2492 | +         new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
| 2493 | +         new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
| 2494 | +         new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
| 2495 | +             old_prefix + "attn.q_norm_y.weight"
| 2496 | +         )
| 2497 | +         new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
| 2498 | +             old_prefix + "attn.k_norm_y.weight"
| 2499 | +         )
| 2500 | +         if i < num_layers - 1:
| 2501 | +             new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
| 2502 | +                 old_prefix + "attn.proj_y.weight"
| 2503 | +             )
| 2504 | +             new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
| 2505 | + 
| 2506 | +         # MLP
| 2507 | +         new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
| 2508 | +             checkpoint.pop(old_prefix + "mlp_x.w1.weight")
| 2509 | +         )
| 2510 | +         new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
| 2511 | +         if i < num_layers - 1:
| 2512 | +             new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
| 2513 | +                 checkpoint.pop(old_prefix + "mlp_y.w1.weight")
| 2514 | +             )
| 2515 | +             new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
| 2516 | + 
| 2517 | +     # Output layers
| 2518 | +     new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
| 2519 | +     new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
| 2520 | +     new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
| 2521 | +     new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
| 2522 | + 
| 2523 | +     new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
| 2524 | + 
| 2525 | +     return new_state_dict
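
A note on the new `swap_proj_gate` helper used for the MLP weights above: the original Mochi checkpoint stores the fused feed-forward `w1` weight with its two SwiGLU halves in the opposite row order from what the diffusers feed-forward expects, so the converter swaps the halves before assigning `ff.net.0.proj.weight` and `ff_context.net.0.proj.weight`. The snippet below is a minimal, self-contained illustration of that swap; the helper is repeated so the snippet runs on its own, and the tensor shapes are made up for the example rather than Mochi's real dimensions.

```python
import torch


def swap_proj_gate(weight):
    # Same helper as in the diff: split the fused rows in half and move the
    # second half ("gate") in front of the first half ("proj").
    proj, gate = weight.chunk(2, dim=0)
    return torch.cat([gate, proj], dim=0)


# Toy fused w1 weight: the first half stands in for the projection rows,
# the second half for the gate rows (illustrative shapes only).
proj_rows = torch.zeros(4, 8)
gate_rows = torch.ones(4, 8)
fused_w1 = torch.cat([proj_rows, gate_rows], dim=0)

swapped = swap_proj_gate(fused_w1)
assert torch.equal(swapped[:4], gate_rows)
assert torch.equal(swapped[4:], proj_rows)
```

With the detection key, the default pipeline path, and this converter in place, the intent appears to be that a single-file Mochi transformer checkpoint (including ComfyUI-style files that use the `model.diffusion_model.` prefix) can be loaded through the usual `from_single_file` machinery, with `infer_diffusers_model_type` matching on the `blocks.0.attn.qkv_x.weight` key.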