We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dd1da39 commit 80cc859 Copy full SHA for 80cc859
examples/language_model/moe/dygraph/run_moe_pretrain.py
@@ -551,7 +551,7 @@ def do_train(args):
551
552
if args.gate != "naive" and args.balance_loss_weight:
553
aux_loss_list = [
554
- l.moe_mlp.gate.get_loss(clear=False)
+ l.moe_mlp.gate.get_loss(clear=False).reshape([-1])
555
for l in model.gpt.decoder.layers
556
if hasattr(l.moe_mlp, "gate")
557
]
0 commit comments