 param_sharding_plan = {
-    "model.embed_tokens.weight": [Replicate()],
-    r"model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
-    r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
-    r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
-    r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
+    "mixtral_model.model.embed_tokens.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
+    r"mixtral_model.model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
     # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
-    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
-    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
-    r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
-    r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
-    r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
-    "model.norm.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
+    "mixtral_model.model.norm.weight": [Replicate()],
 }
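The parameter plan maps fully qualified parameter names (plain strings, or regexes that cover every layer and expert index) to a list of placements: Shard(0) splits a weight along dim 0 (column-parallel for q_proj/k_proj/v_proj and the expert w1/w3 projections), Shard(1) splits along dim 1 (row-parallel for o_proj and w2), and Replicate() keeps a full copy on every rank, i.e. the usual Megatron-style tensor-parallel layout. A minimal sketch of how such regex keys could be resolved against parameter FQNs is shown below; the helper and its full-match rule are illustrative assumptions, not this framework's actual lookup code.

import re

def lookup_placement(plan, fqn):
    """Return the placement list for a fully qualified parameter name, or None.

    Illustrative only: plain keys and regex keys are both treated as
    full-string regex patterns, which is an assumed matching rule.
    """
    for pattern, placements in plan.items():
        if re.fullmatch(pattern, fqn):
            return placements
    return None

# With the keys now prefixed by "mixtral_model.", the FQNs are expected to come
# from a wrapper module whose attribute holding the HF model is named
# mixtral_model (inferred from the renamed keys). For example:
#   lookup_placement(param_sharding_plan, "mixtral_model.model.layers.0.self_attn.q_proj.weight")
# returns [Shard(0)].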
 
 fwd_resharding_plan = {
     # TODO: buggy: attn mask is torch.Tensor, in training, it's a None
     r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
-    "model.embed_tokens.input": [[Replicate()]],
+    "mixtral_model.model.embed_tokens.input": [[Replicate()]],
     # No SP
     # r"layers.\d+.input_layernorm.input": [[Replicate()]],
     # r"layers.\d+.input_layernorm.output": [[Replicate()]],
     # SP
-    r"model.layers.\d+.input_layernorm.input": [[Shard(1)]],
-    r"model.layers.\d+.input_layernorm.output": [[Shard(1)]],
-    r"model.layers.\d+.self_attn.input": [[Replicate()]],
-    r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
-    r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.input_layernorm.input": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.input_layernorm.output": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.self_attn.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.self_attn.output": {
+        "attn_output": [Replicate()],
+        "attn_weights": None,
+        "past_key_value": None,
+    },
+    r"mixtral_model.model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
     # No SP
     # r"model.layers.\d+.post_attention_layernorm.input": [[Replicate()]],
     # r"model.layers.\d+.post_attention_layernorm.output": [[Replicate()]],
     # SP
-    r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
-    r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
-    r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
-    "model.norm.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.output": {
+        "final_hidden_states": [Replicate()],
+        "router_logits": [Replicate()],
+    },
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
+    "mixtral_model.model.norm.input": [[Replicate()]],
 }
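In the forward plan, keys ending in .input or .output describe how activations are resharded at module boundaries: a nested list gives placements for positional tensors, a dict keys placements by tensor name (attn_output, router_logits, ...), and None presumably leaves that value untouched. The SP entries reshard the layernorm input and output with Shard(1), i.e. along the sequence dimension of a [batch, seq, hidden] activation, while the attention and MoE blocks themselves consume replicated inputs. A shape-only sketch of that split follows; the sizes are assumptions for illustration.

# Shape-only illustration of the SP (sequence-parallel) entries above.
batch, seq, hidden, tp_size = 2, 4096, 4096, 8       # assumed sizes

replicated = (batch, seq, hidden)                    # [[Replicate()]]: every rank sees all tokens
seq_sharded = (batch, seq // tp_size, hidden)        # [[Shard(1)]]: each rank keeps seq // tp_size tokens

print(replicated, "->", seq_sharded)                 # (2, 4096, 4096) -> (2, 512, 4096)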
 
 mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
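For completeness, a hedged sketch of how a combined plan like mixtral_plan is typically consumed, assuming a veScale-style parallelize_module(module, device_mesh, plan) entry point and a wrapper whose mixtral_model attribute holds the HF model; the import paths, mesh setup, wrapper class, and signature are assumptions inferred from the key prefix, not verified against this repo.

import torch.nn as nn
from transformers import MixtralConfig, MixtralForCausalLM

# Assumed veScale-style entry points -- verify against the framework's actual API.
from vescale.dtensor.device_mesh import init_device_mesh
from vescale.dmodule.api import parallelize_module


class MixtralWrapper(nn.Module):
    # Hypothetical wrapper; the attribute name mixtral_model is inferred from
    # the "mixtral_model." key prefix introduced in this commit.
    def __init__(self, config):
        super().__init__()
        self.mixtral_model = MixtralForCausalLM(config)

    def forward(self, *args, **kwargs):
        return self.mixtral_model(*args, **kwargs)


# Tiny config so the sketch stays cheap to instantiate; real runs would load
# the actual Mixtral checkpoint under torchrun with one process per GPU.
tiny_cfg = MixtralConfig(hidden_size=256, intermediate_size=512,
                         num_hidden_layers=2, num_attention_heads=8,
                         num_key_value_heads=8)

device_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("TP",))  # 8-way tensor parallel (assumed)
model = MixtralWrapper(tiny_cfg)
sharded_model = parallelize_module(model, device_mesh, mixtral_plan)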