`unsloth/kernels/moe/README.md`
... sets of tests. E.g., to run forward tests with autotune turned on: `pytest -sv ...`
- `grouped_gemm/tests/test_qwen3_moe.py`: end-to-end test for the Qwen3 MoE block. IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as the notes in the test itself for complications when running parametrized pytest test suites with triton / autotune. TL;DR: use the test script, NOT pytest, to run these tests.
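
A minimal sketch of launching the suite, assuming `run_qwen3_moe_tests.sh` sits alongside the test under `grouped_gemm/tests/` and is run from this directory:

```bash
# Sketch: run the end-to-end Qwen3 MoE tests via the provided script
# rather than invoking pytest directly (the script path is an assumption).
bash grouped_gemm/tests/run_qwen3_moe_tests.sh
```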
### Benchmarks
- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks the HF `Qwen3MoeSparseMoeBlock` or `Llama4TextMoe` against the fused implementation.
Running with these flags on an `H100` to bench the forward pass (run with `--help` to see all available flags):
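
The flag values themselves are defined by the benchmark script; as a minimal sketch (assuming the benchmark is invoked as a plain Python script from this directory), the supported options can be listed with:

```bash
# Sketch: enumerate the benchmark's flags before picking a configuration
# (invocation path is an assumption; only --help is taken from the text above).
python grouped_gemm/benchmark/benchmark_fused_moe.py --help
```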
- Tested and benched on `H100`; it should run on Ampere and possibly even earlier GPU generations, though the autotuning configs will need to be adjusted.
- The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`.
- The kernels can be run either autotuned (see `autotuning.py`) or with a manually specified config (see `tuning.py`). Running with the autotuner is recommended, since the MoE block requires 2 configs for the forward pass (2 grouped GEMMs) and 4 for the backward pass (dX and dW per grouped GEMM, 2 grouped GEMMs).
- Running with autotuning turned off and the default manual kernel config will result in **highly** sub-optimal performance, as that config is only meant for testing / debugging purposes.
- I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads.
- The Llama4 reference layer is still highly under-optimized; there are many low-hanging opportunities for further speedups around routing and shared expert calculation.
TODO:
- TMA store: implemented but not currently enabled due to non-determinism arising from a triton pipelining bug.
- Warp specialization: Hopper support for WS is not yet enabled on the triton 3.3.x branch that ships with the latest pytorch 2.7.
- Additional optimizations:
  - Fused / optimized implementations of routing, token sorting, etc.