
Commit ac76ffa

[moe] feat: enabling expert parallelism in veScale (#59)
## Overview

veScale provides an efficient framework for training Mixture-of-Experts (MoE) models using expert parallelism. Expert parallelism can be deployed with the `parallelize_experts()` function, which simplifies the process of distributing and managing workload during MoE training.

### Function Signature

```python
model = parallelize_experts(
    module: nn.Module,
    experts_expr: Union[str, List[str]],
    experts_allocator: vescale.moe.ExpertsAllocator,
    token_dispatcher: vescale.moe.TokenDispatcher,
    config: Dict,
)
```

### Parameters

- **`module`**: The training model (an instance of `nn.Module`) to be parallelized.
- **`experts_expr`**: Specifies the paths to the expert modules. Can be a string or a list of strings.
- **`experts_allocator`**: An instance of `ExpertsAllocator`, used for managing expert parameter allocation.
- **`token_dispatcher`**: An instance of `TokenDispatcher`, responsible for token scheduling and distribution.
- **`config`**: A dictionary containing the MoE training configuration, including layer count, number of experts, and other relevant settings.

## Custom Scheduling

veScale allows users to define custom scheduling strategies for expert parallelism by implementing the following components:

- **`ExpertsAllocator`**: Manages expert parameter allocation. It can use `collect_performance()` to profile and dynamically adjust the DP x TP device mesh for each expert. By default, veScale shards all expert parameters across devices using tensor parallelism.
- **`TokenDispatcher`**: Handles token distribution. Using `assign_task()`, it determines workload allocation (e.g., expert IDs and token weights) and adjusts scheduling with `collect_performance()`. The default implementation randomly assigns tokens to a single DP rank for the selected expert.

## Optimizer Support

Since veScale supports dynamic placement of expert parameters, a dedicated optimizer, `MoEOptimizer`, is required. This optimizer handles the redistribution of expert parameters and their states efficiently. Future updates will integrate these functionalities into optimizers for static parameters to streamline the process.

## Getting Started

### Data Preparation

Prepare the Shakespeare dataset by running:

```bash
cd data/shakespeare/
python3 prepare.py
cd ../..
```

### Training Command

```
torchrun --standalone --nproc_per_node={GPU_CNT} mixtral_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters}
```
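The description above names `parallelize_experts()`, `vescale.moe.ExpertsAllocator`, `vescale.moe.TokenDispatcher`, and `MoEOptimizer` but does not show them in use, so here is a minimal, non-authoritative sketch of how the pieces might fit together. The import path `vescale.moe`, the `config` keys, the `experts_expr` regex, the constructor arguments, and the method contract of the custom dispatcher are all assumptions made for illustration; consult the `vescale/moe` sources for the real API.

```python
# Hypothetical sketch -- import paths, config keys, and method signatures are
# assumptions, not taken from this commit.
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

# Assumed import path; only the class/function names appear in the commit text.
from vescale.moe import parallelize_experts, ExpertsAllocator, TokenDispatcher, MoEOptimizer


class RoundRobinTokenDispatcher(TokenDispatcher):
    """Custom scheduling example: spread tokens for each expert across DP ranks
    in round-robin order instead of the default single-rank random assignment."""

    def __init__(self, dp_size):
        super().__init__()
        self.dp_size = dp_size
        self._next_rank = 0

    def assign_task(self, expert_ids, token_weights):
        # Returns (dp_rank, expert_ids, token_weights); the exact return
        # contract is an assumption made for this sketch.
        rank = self._next_rank
        self._next_rank = (self._next_rank + 1) % self.dp_size
        return rank, expert_ids, token_weights

    def collect_performance(self, perf_stats):
        # Hook for feeding runtime profiling back into the scheduling policy.
        pass


# Small Mixtral config so the sketch stays cheap to instantiate.
config = MixtralConfig(
    num_hidden_layers=2,
    num_local_experts=8,
    hidden_size=512,
    intermediate_size=1024,
    num_attention_heads=8,
    num_key_value_heads=8,
)
model = MixtralForCausalLM(config)

model = parallelize_experts(
    module=model,
    experts_expr=r"model.layers.\d+.block_sparse_moe.experts",  # assumed expert-module path
    experts_allocator=ExpertsAllocator(),   # default behavior: TP-shard all expert parameters
    token_dispatcher=RoundRobinTokenDispatcher(dp_size=2),
    config={"num_layers": 2, "num_experts": 8},  # assumed keys
)

# Expert parameters may move between devices, so they go to the dedicated MoEOptimizer.
optimizer = MoEOptimizer(model.parameters(), lr=3e-4)  # assumed constructor signature
```

Note that the example diffs shown below do not exercise this API directly; they carry the supporting changes (loss-computing `Net` wrappers, the `whitelist_module_types` to `module_to_enforce` rename, and re-prefixed sharding-plan paths), while the MoE runtime itself lives in the rest of the 41 changed files.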
1 parent b4b1686 commit ac76ffa


41 files changed (+2444 −132 lines)

examples/llama2_4D_finetune/llama_train.py

Lines changed: 19 additions & 5 deletions

```diff
@@ -37,6 +37,20 @@
 from data_loader import DataLoader
 
 
+class Net(torch.nn.Module):
+    def __init__(self, path, torch_dtype):
+        super().__init__()
+        self.llama_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch_dtype)
+        self.loss_fn = torch.nn.CrossEntropyLoss()
+
+    def forward(self, input_ids, labels):
+        logits = self.llama_model(input_ids).logits
+        logits = logits.flatten(end_dim=-2)
+        labels = labels.flatten()
+        loss = self.loss_fn(logits, labels)
+        return loss
+
+
 def estimate_llama2(config, bsz, sqence_length):
     embed = 4 * bsz * sqence_length * config.hidden_size
     ff = 3 * 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length
@@ -53,7 +67,7 @@ def run_llama2(args):
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
-    device = f"cuda:{rank}"
+    device = f"cuda:{local_rank}"
     torch.cuda.set_device(device)
     dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
     VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"])
@@ -77,8 +91,8 @@ def run_llama2(args):
         "bfloat16": torch.bfloat16,
     }[args.dtype]
 
-    model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b", torch_dtype=ptdtype)
-    llama_config = model.config
+    model = Net("openlm-research/open_llama_3b", torch_dtype=ptdtype)
+    llama_config = model.llama_model.config
     if rank == 0:
         print(model)
         print(llama_config)
@@ -165,7 +179,7 @@ def estimate_loss():
             losses = torch.zeros(args.eval_iters // factor).to(device)
             for k in range(args.eval_iters // factor):
                 X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp)
-                loss = model(X, labels=Y).loss
+                loss = model(X, Y)
                 if world_size > 1:
                     losses[k] = loss.to_local().item()
                 else:
@@ -198,7 +212,7 @@ def estimate_loss():
         start_epoch.record()
         if world_size > 1:
             model.zero_grad_buffer()
-        loss = model(X, labels=Y).loss
+        loss = model(X, Y)
         loss.backward()
         grad_norm = -1
         if world_size == 1 and args.grad_clip > 0:
```

examples/llama2_4D_finetune/sharding_plan.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -45,19 +45,19 @@
 
 # forward resharding plan for the whole open llama model
 model_fwd_resharding_plan = {
-    "model.input": [[Replicate()]],
-    "model.embed_tokens.output": [[Shard(1)]],
-    "model.norm.input": [[Shard(1)]],
-    "model.output": {
+    "llama_model.model.input": [[Replicate()]],
+    "llama_model.model.embed_tokens.output": [[Shard(1)]],
+    "llama_model.model.norm.input": [[Shard(1)]],
+    "llama_model.model.output": {
         "last_hidden_state": [Replicate()],
     },
-    **{rf"model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
+    **{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()},
 }
 
 # model parameter sharding plan for the whole open llama model
 model_param_sharding_plan = {
-    "model.embed_tokens.weight": [Shard(1)],
-    **{rf"model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
+    "llama_model.model.embed_tokens.weight": [Shard(1)],
+    **{rf"llama_model.model.layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()},
 }
 
 llama2_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan}
```

examples/mixtral_4D_benchmark/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -11,7 +11,7 @@ from HuggingFace without any model code modifications.
 
 ### Single Machine 8 cards
 ```
-torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
+torchrun --nproc-per-node=8 --standalone examples/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
 ```
 This will start a 8-cards MFU benchmark for Mixtral with veScale with dp=1 and tp=8.
 
````

examples/mixtral_4D_benchmark/mixtral_train.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -27,7 +27,7 @@
 from vescale.optim.distributed_optimizer import DistributedOptimizer
 from vescale.initialize.deferred_init import deferred_init, is_deferred
 
-from transformers.models.mixtral.modeling_mixtral import MixtralModel
+from transformers.models.mixtral.modeling_mixtral import MixtralModel, MixtralSparseMoeBlock
 from transformers.models.mixtral.configuration_mixtral import MixtralConfig
 from sharding_plan import mixtral_plan
 
@@ -84,7 +84,7 @@ def run_mixtral(args):
         accumulate_allreduce_grads_in_fp32=True,
         overlap_grad_reduce=False,
         use_distributed_optimizer=True,
-        whitelist_module_types=[MixtralSparseMoeBlock],
+        module_to_enforce=[MixtralSparseMoeBlock],
     )
 
     doptim = DistributedOptimizer(
```

examples/mixtral_4D_benchmark/sharding_plan.py

Lines changed: 29 additions & 29 deletions

```diff
@@ -21,49 +21,49 @@
 
 
 param_sharding_plan = {
-    "embed_tokens.weight": [Replicate()],
-    r"layers.\d+.input_layernorm.weight": [Replicate()],  # MixtralRMSNorm
-    r"layers.\d+.self_attn.q_proj.weight": [Shard(0)],
-    r"layers.\d+.self_attn.k_proj.weight": [Shard(0)],
-    r"layers.\d+.self_attn.v_proj.weight": [Shard(0)],
+    "model.embed_tokens.weight": [Replicate()],
+    r"model.layers.\d+.input_layernorm.weight": [Replicate()],  # MixtralRMSNorm
+    r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
+    r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
+    r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
     # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
-    r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
-    r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
-    r"layers.\d+.self_attn.o_proj.weight": [Shard(1)],
-    r"layers.\d+.post_attention_layernorm.weight": [Replicate()],
-    r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
-    r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
-    r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
-    r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
-    "norm.weight": [Replicate()],
+    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
+    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
+    r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
+    r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
+    r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
+    "model.norm.weight": [Replicate()],
 }
 
 fwd_resharding_plan = {
     # TODO: buggy: attn mask is torch.Tensor, in training, it's a None
     r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
-    "embed_tokens.input": [[Replicate()]],
+    "model.embed_tokens.input": [[Replicate()]],
     # No SP
     # r"layers.\d+.input_layernorm.input": [[Replicate()]],
     # r"layers.\d+.input_layernorm.output": [[Replicate()]],
     # SP
-    r"layers.\d+.input_layernorm.input": [[Shard(1)]],
-    r"layers.\d+.input_layernorm.output": [[Shard(1)]],
-    r"layers.\d+.self_attn.input": [[Replicate()]],
-    r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
-    r"layers.\d+.self_attn.o_proj.output": [[Replicate()]],
+    r"model.layers.\d+.input_layernorm.input": [[Shard(1)]],
+    r"model.layers.\d+.input_layernorm.output": [[Shard(1)]],
+    r"model.layers.\d+.self_attn.input": [[Replicate()]],
+    r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
+    r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
     # No SP
     # r"layers.\d+.post_attention_layernorm.input": [[Replicate()]],
     # r"layers.\d+.post_attention_layernorm.output": [[Replicate()]],
     # SP
-    r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
-    r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
-    r"layers.\d+.block_sparse_moe.input": [[Replicate()]],
-    r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
-    r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
-    r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
-    r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
-    r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
-    "norm.input": [[Replicate()]],
+    r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
+    r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
+    r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
+    r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
+    r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
+    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
+    "model.norm.input": [[Replicate()]],
 }
 
 mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
```

examples/mixtral_4D_training/mixtral_train.py

Lines changed: 21 additions & 12 deletions

```diff
@@ -37,6 +37,20 @@
 from data_loader import DataLoader
 
 
+class Net(torch.nn.Module):
+    def __init__(self, mixtral_config):
+        super().__init__()
+        self.mixtral_model = MixtralForCausalLM(mixtral_config)
+        self.loss_fn = torch.nn.CrossEntropyLoss()
+
+    def forward(self, input_ids, labels):
+        logits = self.mixtral_model(input_ids).logits
+        logits = logits.flatten(end_dim=-2)
+        labels = labels.flatten()
+        loss = self.loss_fn(logits, labels)
+        return loss
+
+
 def estimate_mixtral(config, bsz, sqence_length):
     embed = 4 * bsz * sqence_length * config.hidden_size
     # MixtralMoE consists of 3 linear layers.
@@ -57,7 +71,7 @@ def run_mixtral(args):
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
-    device = f"cuda:{rank}"
+    device = f"cuda:{local_rank}"
     torch.cuda.set_device(device)
     dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
     VESCALE_DEVICE_MESH.init_device_mesh(device, (args.dp, args.tp), mesh_dim_names=["DP", "TP"])
@@ -90,7 +104,7 @@ def run_mixtral(args):
     )
 
     if world_size > 1:
-        model = MixtralForCausalLM(mixtral_config)
+        model = Net(mixtral_config)
         model.to(ptdtype)
 
         model = parallelize_module(
@@ -104,11 +118,11 @@
             model,
             VESCALE_DEVICE_MESH["DP"],
             accumulate_allreduce_grads_in_fp32=False,
-            use_distributed_optimizer=True,
-            whitelist_module_types=[MixtralSparseMoeBlock],
+            use_distributed_optimizer=args.use_DO,
+            module_to_enforce=[MixtralSparseMoeBlock],
         )
     else:
-        model = MixtralForCausalLM(mixtral_config).to(device)
+        model = Net(mixtral_config).to(device)
         model.to(ptdtype)
     print(f"rank {rank} cuda.rng_state {torch.cuda.get_rng_state().view(torch.int64)}")
 
@@ -170,7 +184,7 @@ def estimate_loss():
             losses = torch.zeros(args.eval_iters // factor).to(device)
             for k in range(args.eval_iters // factor):
                 X, Y = data_loader.get_batch(split, args.bsz * factor, factor * args.bsz // args.dp)
-                loss = model(X, labels=Y).loss
+                loss = model(X, Y)
                 if world_size > 1:
                     losses[k] = loss.to_local().item()
                 else:
@@ -203,7 +217,7 @@ def estimate_loss():
         start_epoch.record()
        if world_size > 1:
             model.zero_grad_buffer()
-        loss = model(X, labels=Y).loss
+        loss = model(X, Y)
         loss.backward()
         grad_norm = -1
         if world_size == 1 and args.grad_clip > 0:
@@ -274,11 +288,6 @@ def parse_args():
     parser.add_argument("--num_hidden_layers", type=int, default=2)
     parser.add_argument("--num_attention_heads", type=int, default=8)
     parser.add_argument("--num_key_value_heads", type=int, default=8)
-    # parser.add_argument("--hidden_size", type=int, default=4096)
-    # parser.add_argument("--intermediate_size", type=int, default=14336)
-    # parser.add_argument("--num_hidden_layers", type=int, default=16)
-    # parser.add_argument("--num_attention_heads", type=int, default=32)
-    # parser.add_argument("--num_key_value_heads", type=int, default=8)
 
     # Optimizer related
     parser.add_argument("--use_DO", type=bool, default=True)
```

examples/mixtral_4D_training/sharding_plan.py

Lines changed: 36 additions & 29 deletions

```diff
@@ -21,49 +21,56 @@
 
 
 param_sharding_plan = {
-    "model.embed_tokens.weight": [Replicate()],
-    r"model.layers.\d+.input_layernorm.weight": [Replicate()],  # MixtralRMSNorm
-    r"model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
-    r"model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
-    r"model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
+    "mixtral_model.model.embed_tokens.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.input_layernorm.weight": [Replicate()],  # MixtralRMSNorm
+    r"mixtral_model.model.layers.\d+.self_attn.q_proj.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.self_attn.k_proj.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.self_attn.v_proj.weight": [Shard(0)],
     # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
-    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
-    r"model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
-    r"model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
-    r"model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
-    r"model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
-    "model.norm.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
+    r"mixtral_model.model.layers.\d+.self_attn.o_proj.weight": [Shard(1)],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
+    "mixtral_model.model.norm.weight": [Replicate()],
 }
 
 fwd_resharding_plan = {
     # TODO: buggy: attn mask is torch.Tensor, in training, it's a None
     r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
-    "model.embed_tokens.input": [[Replicate()]],
+    "mixtral_model.model.embed_tokens.input": [[Replicate()]],
     # No SP
     # r"layers.\d+.input_layernorm.input": [[Replicate()]],
     # r"layers.\d+.input_layernorm.output": [[Replicate()]],
     # SP
-    r"model.layers.\d+.input_layernorm.input": [[Shard(1)]],
-    r"model.layers.\d+.input_layernorm.output": [[Shard(1)]],
-    r"model.layers.\d+.self_attn.input": [[Replicate()]],
-    r"model.layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
-    r"model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.input_layernorm.input": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.input_layernorm.output": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.self_attn.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.self_attn.output": {
+        "attn_output": [Replicate()],
+        "attn_weights": None,
+        "past_key_value": None,
+    },
+    r"mixtral_model.model.layers.\d+.self_attn.o_proj.output": [[Replicate()]],
     # No SP
     # r"model.layers.\d+.post_attention_layernorm.input": [[Replicate()]],
     # r"model.layers.\d+.post_attention_layernorm.output": [[Replicate()]],
     # SP
-    r"model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
-    r"model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
-    r"model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
-    r"model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
-    "model.norm.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.output": {
+        "final_hidden_states": [Replicate()],
+        "router_logits": [Replicate()],
+    },
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
+    r"mixtral_model.model.layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
+    "mixtral_model.model.norm.input": [[Replicate()]],
 }
 
 mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
```
