
Conversation

@jianyingzhu (Contributor) commented Sep 3, 2025

Motivation

This pull request adds support for the SM120 architecture in the Cutlass FP8 blockwise GEMM kernel, enabling improved performance on the latest hardware.

Modifications

Introduced an SM120-specific branch in fp8_blockwise_gemm_kernel.cu to leverage new features and optimizations of Cutlass.
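A minimal sketch of the shape of the change (illustrative only; getSMVersion() and the sm90 fallback name are assumptions, not the PR's literal diff):

```cpp
// Illustrative sketch: route to an SM120-specific instantiation alongside the
// existing architecture paths. getSMVersion() and the sm90 branch are assumed
// helpers for illustration, not necessarily the PR's exact code.
int sm = getSMVersion();
if (sm == 120) {
  sm120_fp8_blockwise_dispatch_shape<OutType>(out, a, b, scales_a, scales_b);
} else if (sm == 100) {
  sm100_fp8_blockwise_dispatch_shape<OutType>(out, a, b, scales_a, scales_b);
} else {
  sm90_fp8_blockwise_dispatch_shape<OutType>(out, a, b, scales_a, scales_b);
}
```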

Accuracy Tests

[screenshot: accuracy test results]

Benchmarking and Profiling

Benchmarking on the Qwen3-32B model shows a roughly 56% throughput increase over the Triton FP8 blockwise GEMM kernel, from 1585 tok/s to 2473 tok/s.

MLP up with the CUTLASS kernel: 30 ms [profile screenshot]
MLP up with the Triton kernel: 60 ms [profile screenshot]

Checklist

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @jianyingzhu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for the sm120 GPU architecture in the fp8 blockwise GEMM kernel, leveraging CUTLASS. The primary goal is to enhance computational efficiency and throughput for specific hardware configurations, leading to significant performance gains in operations like those found in large language models.

Highlights

  • sm120 Architecture Support: Implemented support for the sm120 architecture within the CUTLASS fp8 blockwise GEMM kernel.
  • Performance Improvement: Achieved a roughly 56% throughput increase on Qwen3-32B compared to the Triton fp8 blockwise GEMM kernel.

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for the sm120 architecture to the FP8 blockwise GEMM kernel. The implementation largely follows the existing structure for other architectures. My review focuses on potential correctness and performance improvements. I've identified a likely incorrect scheduler configuration for the cooperative kernel, suggested adding problem-size-dependent tile shape dispatching to optimize performance, and pointed out significant code duplication that could be refactored to improve maintainability.

Comment on lines +294 to +298
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,  // Indicates ProblemShape
    CollectiveMainloop,
    CollectiveEpilogue,
    void>;

Severity: high

The TileScheduler template argument for cutlass::gemm::kernel::GemmUniversal is specified as void, which defaults to cutlass::gemm::PersistentScheduler. However, the comment for CollectiveMainloop on line 290 states that KernelScheduleAuto defaults to a cooperative kernel schedule. PersistentScheduler is generally used for non-cooperative kernels. Using it with a cooperative mainloop might lead to incorrect behavior or suboptimal performance. For a cooperative kernel, you should use cutlass::gemm::CooperativeScheduler.

      cutlass::gemm::CooperativeScheduler>;
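For context, a hedged sketch of what an explicit scheduler argument looks like (PersistentScheduler shown since it is the stated default; substitute the cooperative scheduler type suggested above if your CUTLASS version provides it):

```cpp
// Sketch: the fourth template parameter of GemmUniversal selects the tile
// scheduler explicitly rather than relying on the `void` default. Assumes the
// CollectiveMainloop/CollectiveEpilogue types defined in the code above.
using GemmKernelExplicit = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,            // ProblemShape
    CollectiveMainloop,
    CollectiveEpilogue,
    cutlass::gemm::PersistentScheduler>;  // explicit scheduler tag
```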

Comment on lines +198 to +351
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;
  // FP8 Block-wise scaling configuration
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      PerSmTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutCTag,
      AlignmentC,
      ElementD,
      LayoutDTag,
      AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutATag, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutBTag, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto defaults to cooperative kernel
                                                     // schedule
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  auto scales_a_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};

  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, stride_c, c_ptr, stride_c};
  epilogue_args.thread.alpha = 1.0f;

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto init_status = gemm_op.initialize(args, workspace.get());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}

Severity: medium

The new function launch_sm120_fp8_blockwise_scaled_mm is very similar to the existing launch_sm100_fp8_blockwise_scaled_mm. There is a large amount of duplicated code for type definitions, argument setup, and kernel launching logic. This makes the code harder to maintain. Consider refactoring the common code into a template function that can be specialized for different architectures (Sm100, Sm120, etc.) by passing the architecture-specific types and configurations as template parameters.
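A hedged sketch of that refactor (illustrative only; the shared launcher name and the assumption that the architecture tag alone differentiates the builders are mine, not the PR's code):

```cpp
#include <torch/extension.h>

// Illustrative sketch: hoist the duplicated builder/launch logic into a single
// template parameterized on the CUTLASS architecture tag, so the SM100 and
// SM120 entry points become thin wrappers. Names here are assumptions.
template <
    typename ArchTag,  // e.g. cutlass::arch::Sm100 or cutlass::arch::Sm120
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename ScalesPerTile>
void launch_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  // Body: the CollectiveBuilder / GemmUniversal / argument-setup code quoted
  // above, written once, with ArchTag threaded into both CollectiveBuilder
  // instantiations.
}
```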

Comment on lines +353 to +366
template <typename OutType>
void sm120_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using MmaTileShape = Shape<_128, _128, _128>;
  using PerSmTileShape = Shape<_128, _128, _128>;
  using EpilogueTileShape = Shape<_128, _64>;
  using ScalesPerTile = Shape<_128, _1, _1>;
  launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
      out, a, b, scales_a, scales_b);
}

Severity: medium

The sm120_fp8_blockwise_dispatch_shape function uses a fixed set of tile shapes. The corresponding function for sm100, sm100_fp8_blockwise_dispatch_shape, has logic to select different tile shapes based on the M dimension of the input tensor a (specifically, if (a.size(0) <= 128)). This is a common optimization to improve performance for different problem sizes. You should consider adding similar dispatch logic here for sm120 to select more optimal tile shapes for smaller M dimensions.

template <typename OutType>
void sm120_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  if (a.size(0) <= 128) {
    using MmaTileShape = Shape<_64, _128, _128>;
    using PerSmTileShape = Shape<_64, _128, _128>;
    using EpilogueTileShape = Shape<_64, _64>;
    using ScalesPerTile = Shape<_64, _1, _1>;
    launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  } else {
    using MmaTileShape = Shape<_128, _128, _128>;
    using PerSmTileShape = Shape<_128, _128, _128>;
    using EpilogueTileShape = Shape<_128, _64>;
    using ScalesPerTile = Shape<_128, _1, _1>;
    launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  }
}

@zhyncs (Member) commented Sep 4, 2025

@jianyingzhu https://github.com/sgl-project/sglang/actions/runs/17445169585/job/49551658420 ut failed

@jianyingzhu (Contributor, Author) commented:

I only modified one file (fp8_blockwise_gemm_kernel.cu), but many kernels that I didn’t change are failing in the unit tests, such as _test_flash_attn_varlen_output and test_cutlass_w4a8_moe_mm.py. I also noticed that others’ PRs on sm120 are encountering the same issues in their CI runs, for example: https://github.com/sgl-project/sglang/actions/runs/17450401148/job/49558258257?pr=9992

@zhyncs self-assigned this Sep 6, 2025
@zhyncs merged commit dd1e268 into sgl-project:main Sep 7, 2025
59 of 63 checks passed
@voipmonitor (Contributor) commented:

@jianyingzhu I suppose this PR does not support GLM-4.5-Air-FP8? The latest compiled sgl-kernel and sglang still fail with:

File "/sgl-workspace/sglang.git/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 164, in apply_weights
return apply_fp8_linear(
^^^^^^^^^^^^^^^^^
File "/sgl-workspace/sglang.git/python/sglang/srt/layers/quantization/fp8_utils.py", line 596, in apply_fp8_linear
output = fp8_scaled_mm(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/gemm.py", line 36, in fp8_scaled_mm
return torch.ops.sgl_kernel.fp8_scaled_mm.default(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 829, in call
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected status == cutlass::Status::kSuccess to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

(python -m sglang.launch_server --model /mnt/GLM-4.5-Air-FP8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5001 --mem-fraction-static 0.95 --context-length 128000)

@jianyingzhu (Contributor, Author) commented:

Yes, this PR only supports fp8_blockwise_scaled_mm on sm120; it does not currently support fp8_scaled_mm on sm120. The GLM-4.5-Air-FP8 model calls fp8_scaled_mm, whose sm120 support may correspond to this PR: #9403. Thank you!

@voipmonitor (Contributor) commented:

Thank you - what are the exact configuration options to run fp8 blockwise GEMM (Qwen3) models with this PR? I'm trying it - looks good, but I want to be sure I used all the relevant parameters.

@voipmonitor (Contributor) commented:

Also - there is this PR in flashinfer - flashinfer-ai/flashinfer#1610 (cutlass fp8 gemm bringup for SM120 & SM121). Is it related to this PR in any way, or is it flashinfer's FP8 sm120 support with different kernels?

@jianyingzhu (Contributor, Author) commented:

My model weights are Qwen3-32B-FP8, with the quantization config set to blockwise.
Since deepgemm currently does not support sm120, you need to set the environment variable:
export SGL_ENABLE_JIT_DEEPGEMM=0
and enable the environment variable for cutlass fp8:
export SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1
You also need to ensure that the CUDA version is >= 12.9, because of this line in the code:

if cuda_version >= (12, 0) and sm_version >= 90:

Therefore, I am using the docker image lmsysorg/sglang:b200-cu129.

@voipmonitor (Contributor) commented:

SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /mnt/Qwen3-30B-A3B-Instruct-2507-FP8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics

185 tokens/sec

vs

SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1 SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /mnt/Qwen3-30B-A3B-Instruct-2507-FP8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics

155 tokens/sec

For a single inference - is it expected that SGLANG_SUPPORT_CUTLASS_BLOCK_FP8 is that much slower?

I have compiled this PR inside the lmsysorg/sglang:b200-cu129 docker (with the latest flashinfer 0.3.1).

MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
@jianyingzhu (Contributor, Author) commented Sep 9, 2025

I tried the same commands on my machine and tested the performance. On my machine, throughput with SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1 is doubled: 2010.77 tok/s without it vs 4120.86 tok/s with it, as shown below:

SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /workspace/qwen3-30b-fp8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics
 
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 128 --random-input 1024 --random-output 1 --random-range-ratio 1 --host 0.0.0.0 --port 5000 --max-concurrency 128 --profile

 

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  65.19
Total input tokens:                      131072
Total generated tokens:                  128
Total generated tokens (retokenized):    128
Request throughput (req/s):              1.96
Input token throughput (tok/s):          2010.77
Output token throughput (tok/s):         1.96
Total token throughput (tok/s):          2012.73
Concurrency:                             33.81
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   17220.31
Median E2E Latency (ms):                 17181.33
---------------Time to First Token----------------
Mean TTFT (ms):                          17220.30
Median TTFT (ms):                        17181.33
P99 TTFT (ms):                           19443.55
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P95 ITL (ms):                            0.00
P99 ITL (ms):                            0.00
Max ITL (ms):                            0.00

 

SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1 SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /workspace/qwen3-30b-fp8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics
 
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 128 --random-input 1024 --random-output 1 --random-range-ratio 1 --host 0.0.0.0 --port 5000 --max-concurrency 128 --profile

 

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  31.81
Total input tokens:                      131072
Total generated tokens:                  128
Total generated tokens (retokenized):    128
Request throughput (req/s):              4.02
Input token throughput (tok/s):          4120.86
Output token throughput (tok/s):         4.02
Total token throughput (tok/s):          4124.89
Concurrency:                             8.70
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2161.35
Median E2E Latency (ms):                 2170.09
---------------Time to First Token----------------
Mean TTFT (ms):                          2161.34
Median TTFT (ms):                        2170.09
P99 TTFT (ms):                           3593.48
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P95 ITL (ms):                            0.00
P99 ITL (ms):                            0.00
Max ITL (ms):                            0.00

 

@voipmonitor (Contributor) commented Sep 9, 2025

@jianyingzhu I'm getting 4374 input token throughput without SGLANG_SUPPORT_CUTLASS_BLOCK_FP8 - I cannot reproduce your low numbers.

Can you please try my docker image which I have just pushed: docker.io/voipmonitor/b200-cu129-pr9969

Run the command:

SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /mnt/Qwen3-30B-A3B-Instruct-2507-FP8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics

Inside the docker I have optimised Triton JSON config files for my NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition - this should be the same as your 50XX, as the chip is the same as the RTX 6000 Pro; only the device name will differ. Just copy the two files N=2560,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json and /sgl-workspace/sglang.git/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json to your JSON file names (you will see which files you need once you run the command).

This is my output:

[2025-09-09 14:34:47 TP0] Using configuration from /sgl-workspace/sglang.git/python/sglang/srt/layers/quantization/configs/N=2048,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json for W8A8 Block FP8 kernel.
[2025-09-09 14:34:47 TP1] Using MoE kernel config from /sgl-workspace/sglang.git/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json.
[2025-09-09 14:34:47 TP0] Using MoE kernel config from /sgl-workspace/sglang.git/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128, 128].json.


And this is the result using python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 128 --random-input 1024 --random-output 1 --random-range-ratio 1 --host 0.0.0.0 --port 5000 --max-concurrency 128 --profile:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  29.99
Total input tokens:                      131072
Total generated tokens:                  128
Total generated tokens (retokenized):    128
Request throughput (req/s):              4.27
Input token throughput (tok/s):          4370.08
Output token throughput (tok/s):         4.27
Total token throughput (tok/s):          4374.35
Concurrency:                             18.81
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   4407.64
Median E2E Latency (ms):                 4282.35
---------------Time to First Token----------------
Mean TTFT (ms):                          4407.63
Median TTFT (ms):                        4282.35
P99 TTFT (ms):                           7811.77
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P95 ITL (ms):                            0.00
P99 ITL (ms):                            0.00
Max ITL (ms):                            0.00
==================================================

This is for:

SGLANG_SUPPORT_CUTLASS_BLOCK_FP8=1 SGL_ENABLE_JIT_DEEPGEMM=False python -m sglang.launch_server --model /mnt/Qwen3-30B-A3B-Instruct-2507-FP8/ --tp 2 --attention-backend flashinfer --host 0.0.0.0 --port 5000 --mem-fraction-static 0.95 --context-length 128000 --enable-metrics

Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     128
Benchmark duration (s):                  30.83
Total input tokens:                      131072
Total generated tokens:                  128
Total generated tokens (retokenized):    128
Request throughput (req/s):              4.15
Input token throughput (tok/s):          4251.86
Output token throughput (tok/s):         4.15
Total token throughput (tok/s):          4256.01
Concurrency:                             17.02
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   4099.54
Median E2E Latency (ms):                 4060.97
---------------Time to First Token----------------
Mean TTFT (ms):                          4099.54
Median TTFT (ms):                        4060.96
P99 TTFT (ms):                           7257.00
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P95 ITL (ms):                            0.00
P99 ITL (ms):                            0.00
Max ITL (ms):                            0.00
==================================================
