
Commit 0ba1bda

Llama4 MoE Grouped GEMM (#2639)
* add llama4 reference layer
* add llama4 reference impl
* formatting
1 parent b8c9755 commit 0ba1bda

File tree

9 files changed: +1418 -372 lines changed


unsloth/kernels/moe/README.md

Lines changed: 18 additions & 5 deletions
````diff
@@ -43,30 +43,43 @@ sets of tests. E.g., to run forward tests with autotune turned on: `pytest -sv
 - `grouped_gemm/tests/test_qwen3_moe.py`: end to end test for Qwen3 MoE block. IMPORTANT: read `tests/run_qwen3_moe_tests.sh` as well as notes in the test itself for complications when running parametrized pytest test suites and triton / autotune. TLDR: use the test script and NOT pytest to run the tests.
 
 ### Benchmarks
-- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3SpareMOEBlock` against the fused implementation
+- `grouped_gemm/benchmark/benchmark_fused_moe.py`: benchmarks HF `Qwen3SpareMOEBlock` or `Llama4TextMoe` against the fused implementation
+
+
 Running with these flags on an `H100` to bench the forward pass (run with `--help` to see all available flags):
+
+For `Qwen3-30B-A3B`:
 ```
-python benchmark/benchmark_fused_moe.py --mode forward --seqlen 1024 --permute_x --permute_y --autotune
+python benchmark/benchmark_fused_moe.py --model qwen3 --mode forward --seqlen 1024 --permute_x --permute_y --autotune
 ```
 
 For the backward bench:
 ```
-python benchmark/benchmark_fused_moe.py --mode backward --seqlen 1024 --permute_x --permute_y --autotune
+python benchmark/benchmark_fused_moe.py --model qwen3 --mode backward --seqlen 1024 --permute_x --permute_y --autotune
 ```
 
-On my machine and env, I get speedups > 25x and 14x respectively.
+For `Llama-4-Scout-17B-16E`:
+```
+python benchmark/benchmark_fused_moe.py --model llama4 --autotune --mode=forward --permute_y
+```
+Ditto for the backward pass.
 
 ### Notes
 - Tested and benched on `H100`, though it should run on Ampere and possibly even earlier GPU generations; the autotuning configs will need to be adjusted.
 - The env I used to develop the kernel was `pytorch 2.7/2.8` and `pytorch-triton 3.3`.
 - The kernels can be run either as autotuned (see `autotuning.py`) or with a manually specified config (see `tuning.py`). Recommended to run using the autotuner, since the MoE block requires 2 configs for the forward (2 grouped gemms) and 4 for the backward (dX and dW per grouped gemm, 2 grouped gemms).
 - Running with autotuning turned off with the default manual kernel config will result in **highly** sub-optimal performance, as it is only meant for testing / debugging purposes.
 - I've tried to strike a balance between compilation time and autotuning search space -- can probably squeeze even more performance for specific workloads.
+- The Llama4 reference layer is still highly under-optimized, as there are many low-hanging opportunities for further speedups around routing and the shared expert calculation.
 
 TODO:
 - TMA store: implemented but not currently enabled due to non-determinism arising from a triton pipelining bug.
 - Warp specialization: Hopper support for WS is not yet enabled on the triton 3.3x branch which ships with the latest pytorch 2.7.
 - Additional optimizations:
   - Fused / optimized implementations of routing, token sorting, etc.
   - Better software pipelining within the grouped gemm
-  - Threadblock swizzling for better L2 caching
+  - Threadblock swizzling for better L2 caching
+- Llama4
+  - Fused gather / topk weight merging
+  - Custom topk, gather indices kernel
+  - Shared expert fusion with experts calculation
````
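For readers skimming the diff, the "grouped GEMM" these kernels implement is one independent matmul per expert over that expert's contiguous slice of routed tokens. Below is a minimal PyTorch sketch of the unfused reference pattern; the names (`grouped_gemm_reference`, `expert_weights`, `expert_ids`) are illustrative, not the repo's API.

```python
import torch

def grouped_gemm_reference(x, expert_weights, expert_ids):
    """Naive reference: x (num_tokens, hidden), expert_weights (num_experts, hidden, out),
    expert_ids (num_tokens,) integer expert assignment per token."""
    num_experts = expert_weights.shape[0]
    # Sort tokens by expert so each expert owns a contiguous slab of rows.
    order = torch.argsort(expert_ids)
    x_sorted = x[order]
    counts = torch.bincount(expert_ids, minlength=num_experts)
    offsets = torch.cumsum(counts, dim=0)

    out_sorted = x.new_empty(x.shape[0], expert_weights.shape[-1])
    start = 0
    for e in range(num_experts):
        end = int(offsets[e])
        if end > start:
            # One independent GEMM per expert over its slice of tokens.
            out_sorted[start:end] = x_sorted[start:end] @ expert_weights[e]
        start = end
    # Undo the sort so outputs line up with the original token order.
    out = torch.empty_like(out_sorted)
    out[order] = out_sorted
    return out
```

A fused grouped-GEMM kernel performs all of these per-expert matmuls in a single launch, deriving each tile's expert from the per-expert token counts instead of looping in Python.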

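Similarly, the Llama4 TODO items (fused gather / topk weight merging, custom topk and gather-indices kernel, shared expert fusion) target the routing glue that runs around the grouped GEMMs. The sketch below shows those steps in unfused form, assuming generic top-k routing plus a shared expert; the sigmoid scoring and the names `route_and_combine`, `experts`, `shared_expert` are placeholders, not the exact `Llama4TextMoe` semantics.

```python
import torch

def route_and_combine(x, router_logits, shared_expert, experts, top_k=1):
    """Naive routing reference: x (num_tokens, hidden), router_logits (num_tokens, num_experts);
    experts[e] and shared_expert are callables mapping (n, hidden) -> (n, hidden)."""
    # 1) topk selection + score normalization (candidate for a custom fused kernel)
    topk_scores, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    topk_weights = torch.sigmoid(topk_scores)  # placeholder scoring choice

    # 2) build gather indices: one row of x per (token, k) pair
    flat_ids = topk_ids.reshape(-1)
    token_idx = torch.arange(x.shape[0], device=x.device).repeat_interleave(top_k)
    gathered = x[token_idx]

    # 3) dispatch each row to its expert -- a per-expert loop instead of a grouped GEMM
    expert_out = torch.zeros_like(gathered)
    for e, expert in enumerate(experts):
        mask = flat_ids == e
        if mask.any():
            expert_out[mask] = expert(gathered[mask])

    # 4) merge topk weights and scatter-add back into the original token order
    weighted = expert_out * topk_weights.reshape(-1, 1)
    routed = torch.zeros_like(x).index_add_(0, token_idx, weighted)

    # 5) shared expert runs on every token and is summed in at the end
    return routed + shared_expert(x)
```

In this naive form, each numbered step costs separate kernel launches and extra memory traffic -- the "low-hanging" overhead that the README's note about the under-optimized Llama4 reference layer refers to.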
0 commit comments
