Skip to content
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8a62c51
Introduce Mooncake Backend and Mooncake EP
UNIDY2002 Sep 11, 2025
1078fb2
tiny fix mooncake pr (#12)
HanHan009527 Sep 11, 2025
74843fb
Merge branch 'main' into mooncake-pr
zhyncs Sep 14, 2025
fa7d9d2
Fix for more readable code
UNIDY2002 Sep 15, 2025
3ebf53b
Merge branch 'main' into mooncake-pr
UNIDY2002 Sep 15, 2025
6a795a4
lint
HanHan009527 Sep 15, 2025
96b2c6b
API: change broken_ranks to active_ranks
UNIDY2002 Sep 16, 2025
b7ba04d
Fix issues according to review
UNIDY2002 Sep 25, 2025
f6ddaeb
Fix an issue
UNIDY2002 Sep 25, 2025
867054b
Merge branch 'main' into mooncake-pr
UNIDY2002 Sep 25, 2025
4659ef6
Bug-fix
UNIDY2002 Sep 25, 2025
ad759f7
Bug-fix
UNIDY2002 Sep 25, 2025
719775a
Change arg `--enable-mooncake-dist-backend` to `--elastic-ep-backend …
UNIDY2002 Sep 30, 2025
ec01f35
Fix format
UNIDY2002 Sep 30, 2025
829e466
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 8, 2025
3c82863
Add ep/test_mooncake_ep_small.py to test suite, and a couple of fixes:
UNIDY2002 Oct 8, 2025
d243eee
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 8, 2025
3bc4150
Remove '--mooncake-ib-device' from tests
UNIDY2002 Oct 9, 2025
da0b6e6
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 9, 2025
f0bec7f
Revert "Remove '--mooncake-ib-device' from tests"
UNIDY2002 Oct 10, 2025
21dacf4
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 10, 2025
d08d368
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 10, 2025
393307f
Merge branch 'main' into mooncake-pr
ShangmingCai Oct 10, 2025
4b02c6e
upd
ShangmingCai Oct 10, 2025
8c87b9e
Update run_suite.py
ShangmingCai Oct 10, 2025
752191a
Install Mooncake+EP in `ci_install_deepep.sh` script
UNIDY2002 Oct 10, 2025
41d9df7
fix ci_install_deepep.sh
ShangmingCai Oct 10, 2025
de0e23d
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 11, 2025
a3c359b
Merge branch 'main' into mooncake-pr
UNIDY2002 Oct 13, 2025
4d9255f
Fix format
UNIDY2002 Oct 13, 2025
dfa76bb
Merge branch 'main' into mooncake-pr
ShangmingCai Oct 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ jobs:
- name: Install dependencies
run: |
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
curl -L https://cloud.tsinghua.edu.cn/f/c22ec766545e48bf99e8/?dl=1 -o mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl
UV_SYSTEM_PYTHON=true uv pip install mooncake_transfer_engine-0.3.6.post1+ep-cp310-cp310-manylinux_2_17_x86_64.manylinux_2_35_x86_64.whl

- name: Run test
timeout-minutes: 20
Expand Down
4 changes: 3 additions & 1 deletion docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | None |
| `--elastic-ep-backend` | Select the collective communication backend for elastic EP. Currently supports 'mooncake'. | None |
| `--mooncake-ib-device` | The InfiniBand devices for the Mooncake Backend; accepts multiple comma-separated devices. Defaults to None, which triggers automatic device detection when the Mooncake Backend is enabled. | None |
| `--tp-size` | The tensor parallelism size. | 1 |
| `--pp-size` | The pipeline parallelism size. | 1 |
| `--pp-max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | None |
Expand Down Expand Up @@ -245,7 +247,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism; can be `deepep` or `mooncake`. | none |
| `--moe-runner-backend` | Select the runner backend for MoE. | auto |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
Expand Down
22 changes: 19 additions & 3 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
direct_register_custom_op,
get_bool_env_var,
get_int_env_var,
get_local_ip_auto,
is_cpu,
is_cuda_alike,
is_hip,
Expand Down Expand Up @@ -256,9 +257,13 @@ def __init__(
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
# a cpu_group to allow direct coordination between processes through
# the CPU. The backend is chosen based on `torch_distributed_backend`
if "mooncake" in torch_distributed_backend:
backend = "mooncake-cpu"
else:
backend = "gloo"
cpu_group = torch.distributed.new_group(ranks, backend=backend)
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
Expand Down Expand Up @@ -1394,6 +1399,17 @@ def init_distributed_environment(
distributed_init_method,
backend,
)
if "mooncake" in backend:
try:
from mooncake import ep as mooncake_ep
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run SGLang with Mooncake Backend."
) from e
mooncake_ep.set_host_ip(get_local_ip_auto())

if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
class DeepEPMoE(FusedMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
Mooncake EP shares the same class, as they expose the same interface.
"""

_has_printed = False
Expand Down Expand Up @@ -686,7 +687,7 @@ def _forward_ll(dispatch_output: DeepEPLLOutput):


def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
if get_moe_a2a_backend().is_deepep():
if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
return DeepEPMoE

# NEW: Direct FP4 detection (bypasses EP requirements)
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/layers/moe/token_dispatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
DeepEPNormalCombineInput,
DeepEPNormalOutput,
)
from sglang.srt.layers.moe.token_dispatcher.mooncake import (
MooncakeCombineInput,
MooncakeDispatchOutput,
MooncakeEPDispatcher,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
Expand All @@ -30,6 +35,9 @@
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"MooncakeCombineInput",
"MooncakeDispatchOutput",
"MooncakeEPDispatcher",
"StandardDispatchOutput",
"StandardCombineInput",
"DeepEPConfig",
Expand Down
Loading
Loading