test/prototype/moe_training  (1 file changed: +8 -0 lines)
 # this test requires torchtitan
 try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
     from torchtitan.experiments.llama4.model.moe import MoE
 except ImportError:
...

 )
 @pytest.mark.parametrize("compile", [False, True])
 def test_moe_float8_training(target_fqns: list[str], compile: bool):
+    # Set the token group alignment size to 16. This is required so that
+    # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
+    # has its contraction dim divisible by 16. 16 byte alignment is required
+    # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
+    set_token_group_alignment_size_m(16)
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
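
For reference, the alignment requirement is simple arithmetic: fp8 elements are 1 byte each, so a 16-byte-aligned contraction dim means each expert's token group must hold a multiple of 16 tokens. Below is a minimal, self-contained sketch of that padding idea; it is not the torchtitan implementation, and the helper name align_token_group_sizes and the list-based layout are assumptions for illustration only.

# Hypothetical sketch of token group padding; torchtitan's
# set_token_group_alignment_size_m(16) configures the same effect internally.

ALIGNMENT_M = 16  # 16-byte alignment / 1 byte per fp8 element = 16 elements


def align_token_group_sizes(group_sizes: list[int], alignment: int = ALIGNMENT_M) -> list[int]:
    """Round each expert's token count up to the nearest multiple of `alignment`.

    In the grouped gemm `grad_weight = grad_output_t @ input`, each expert's
    token count is the contraction dim, so padding it to a multiple of 16
    keeps every per-expert gemm 16-byte aligned in fp8.
    """
    return [((size + alignment - 1) // alignment) * alignment for size in group_sizes]


if __name__ == "__main__":
    raw_sizes = [37, 64, 5, 90]                # tokens routed to 4 experts
    print(align_token_group_sizes(raw_sizes))  # -> [48, 64, 16, 96]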