
[NVIDIA] Support Cutlass w8a8 FP8 for Blackwell Geforce GPUs (sm120) #17280


Merged: 4 commits into vllm-project:main on Jul 2, 2025

Conversation

@kaln27 (Contributor) commented Apr 28, 2025

Add CUTLASS w8a8 support for Blackwell GeForce GPUs (sm120).
Currently, using the sm100 kernel on these GPUs causes an internal error; I don't know the reason.

Works well on an RTX 5070 Ti with the Qwen2.5-VL-7B-Instruct-FP8-Dynamic model, which was quantized using llm-compressor.

FIX #16515
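
For reference, a minimal smoke test along these lines can confirm the new path is exercised on an sm120 card (a sketch, not part of this PR; the checkpoint name is illustrative and any compressed-tensors FP8 model should work):

# Minimal sm120 smoke test (sketch): check the reported compute capability and
# run a short FP8 generation through vLLM's offline API.
import torch
from vllm import LLM, SamplingParams

major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")  # expect 12.0 on Blackwell GeForce

# Illustrative checkpoint; any compressed-tensors FP8 model should exercise the
# CUTLASS w8a8 path added here on sm120.
llm = LLM(model="RedHatAI/Qwen2.5-3B-FP8-dynamic", max_model_len=4096)
outputs = llm.generate(
    ["Briefly explain what FP8 quantization is."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)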

@kaln27 kaln27 requested a review from tlrmchlsmth as a code owner April 28, 2025 03:25
@mergify mergify bot added the ci/build label Apr 28, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@kaln27 kaln27 changed the title [NVIDIA] Support Cutlass w8a8 for Blackwell Geforce GPUs (sm120) (#16… [NVIDIA] Support Cutlass w8a8 for Blackwell Geforce GPUs (sm120) Apr 28, 2025

mergify bot commented Jun 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @kaln27.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 13, 2025
@mergify mergify bot added qwen Related to Qwen models and removed needs-rebase labels Jun 13, 2025

mergify bot commented Jun 18, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @kaln27.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 18, 2025
@voipmonitor

I have verified that this PR works with FP8 models, and it runs at the same speed as TRT-LLM FP8.

@mergify mergify bot removed the needs-rebase label Jun 19, 2025
@cyril23 commented Jun 27, 2025

After merging today's vllm-project/vllm main into https://github.com/kaln27/vllm/tree/main (I did this on my current https://github.com/cyril23/vllm/tree/main/), I built it via:

DOCKER_BUILDKIT=1 sudo docker build \
  --build-arg max_jobs=8   --build-arg nvcc_threads=1 \
  --build-arg USE_SCCACHE=1   --build-arg SCCACHE_S3_NO_CREDENTIALS=1 \
  --build-arg GIT_REPO_CHECK=0 \
  --build-arg CUDA_VERSION=12.8.1 \
  --tag wurstdeploy/vllm:kaln27updated20250627 \
  --target vllm-openai \
  --progress plain \
  -f docker/Dockerfile .

The PyTorch wheel size is still under 400 MB even though I built with the default settings, i.e. torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'. Everything looks fine.

You can try it out on Docker Hub (based on the 27th June 2025 state, after cyril23#2):

  • wurstdeploy/vllm:kaln27updated20250627 (built with the default torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0')
  • wurstdeploy/vllm:kaln27updated-sm120only-20250627 (built using --build-arg torch_cuda_arch_list='12.0')

Strangely, performance is not as good as some older builds from 19th May 2025 (after cyril23#1):

  • wurstdeploy/vllm:myvllm-kaln27-updated (built using --build-arg torch_cuda_arch_list='12.0 12.1')
# FP8:
sudo docker run --network host -e HF_TOKEN=$HF_TOKEN -v ~/inference-benchmarker-results:/opt/inference-benchmarker/results inference_benchmarker inference-benchmarker --no-console --url http://localhost:8000/v1 --max-vus 800 --duration 120s --warmup 30s --benchmark-kind rate --rates 100.0 --prompt-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" --decode-options "num_tokens=200,max_tokens=220,min_tokens=180,variance=10" --model-name "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" --tokenizer-name "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
┌─────────────────┬────────────────────────────────────────────────────────────────┐
│ Parameter       │ Value                                                          │
├─────────────────┼────────────────────────────────────────────────────────────────┤
│ Max VUs         │ 800                                                            │
│ Duration        │ 120                                                            │
│ Warmup Duration │ 30                                                             │
│ Benchmark Kind  │ Rate                                                           │
│ Rates           │ [100.0]                                                        │
│ Num Rates       │ 10                                                             │
│ Prompt Options  │ num_tokens=Some(200),min_tokens=180,max_tokens=220,variance=10 │
│ Decode Options  │ num_tokens=Some(200),min_tokens=180,max_tokens=220,variance=10 │
│ Tokenizer       │ RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8                        │
│ Extra Metadata  │ N/A                                                            │
└─────────────────┴────────────────────────────────────────────────────────────────┘
# OLD build, based 19th May 2025 (after https://github.com/cyril23/vllm/pull/1)
sudo docker run --runtime nvidia --gpus all -v ~/.cache/huggingface:/root/.cache/huggingface -p 8000:8000 wurstdeploy/vllm:myvllm-kaln27-updated --max-model-len 1024 --model RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 --disable-log-requests --max_num_batched_tokens 16384 --enable-chunked-prefill --gpu_memory_utilization 0.9
┌──────────────────────┬─────────────┬───────────────────┬─────────────┬───────────┬────────────────────┬────────────┬─────────────────────┬─────────────────────────────┬──────────────────────────────┐
│ Benchmark            │ QPS         │ E2E Latency (avg) │ TTFT (avg)  │ ITL (avg) │ Throughput         │ Error Rate │ Successful Requests │ Prompt tokens per req (avg) │ Decoded tokens per req (avg) │
├──────────────────────┼─────────────┼───────────────────┼─────────────┼───────────┼────────────────────┼────────────┼─────────────────────┼─────────────────────────────┼──────────────────────────────┤
│ warmup               │ 0.38 req/s  │ 2.65 sec          │ 132.31 ms   │ 14.57 ms  │ 65.25 tokens/sec   │ 0.00%      │ 12/12               │ 200.00                      │ 172.92                       │
│ [email protected]/s │ 30.73 req/s │ 22.39 sec         │ 14496.11 ms │ 43.93 ms  │ 5559.02 tokens/sec │ 0.00%      │ 3675/3675           │ 200.00                      │ 180.89                       │
└──────────────────────┴─────────────┴───────────────────┴─────────────┴───────────┴────────────────────┴────────────┴─────────────────────┴─────────────────────────────┴──────────────────────────────┘
# NEW build "wurstdeploy/vllm:kaln27updated20250627", based 27th June 2025 (after https://github.com/cyril23/vllm/pull/2)
sudo docker run --runtime nvidia --gpus all -v ~/.cache/huggingface:/root/.cache/huggingface -p 8000:8000 wurstdeploy/vllm:kaln27updated20250627 --max-model-len 1024 --model RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8 --disable-log-requests --max_num_batched_tokens 16384 --enable-chunked-prefill --gpu_memory_utilization 0.9
┌──────────────────────┬─────────────┬───────────────────┬─────────────┬───────────┬────────────────────┬────────────┬─────────────────────┬─────────────────────────────┬──────────────────────────────┐
│ Benchmark            │ QPS         │ E2E Latency (avg) │ TTFT (avg)  │ ITL (avg) │ Throughput         │ Error Rate │ Successful Requests │ Prompt tokens per req (avg) │ Decoded tokens per req (avg) │
├──────────────────────┼─────────────┼───────────────────┼─────────────┼───────────┼────────────────────┼────────────┼─────────────────────┼─────────────────────────────┼──────────────────────────────┤
│ warmup               │ 0.30 req/s  │ 3.30 sec          │ 48.31 ms    │ 16.07 ms  │ 61.58 tokens/sec   │ 0.00%      │ 9/9                 │ 200.00                      │ 203.22                       │
│ [email protected]/s │ 20.12 req/s │ 32.49 sec         │ 20778.30 ms │ 66.89 ms  │ 3591.34 tokens/sec │ 0.00%      │ 2399/2399           │ 200.00                      │ 178.51                       │
└──────────────────────┴─────────────┴───────────────────┴─────────────┴───────────┴────────────────────┴────────────┴─────────────────────┴─────────────────────────────┴──────────────────────────────┘

BF16 performance has also degraded relative to the older build:
[benchmark chart omitted]

Furthermore, the very first request after starting vLLM takes 30-60 seconds; it feels like PTX being compiled or something. This only happens with my June builds.

However, I don't think it has anything to do with your code, @kaln27, but rather with some recent changes to the vLLM main branch. Maybe I'm missing some important runtime flags, or I built it wrong?

edit:

Regarding the slow first request mentioned above: apparently #19336 ("Reduce wheel size by only building FA2 8.0+PTX instead of 8.0,9.0,10.0 etc.") is why this happened. Unfortunately we need that PR until the PyPI maximum wheel size has been increased. According to #19336 (comment), I just need more warmup time. edit: more warmup does not help in my case, strange.
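
For reference, this is an easy way to trigger and time that slow first request against the OpenAI-compatible endpoint (a sketch, assuming the server started with the docker run commands above is listening on localhost:8000):

# Time the first few requests after server startup to see the one-off warmup/JIT cost (sketch).
import time
import requests

payload = {
    "model": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    "prompt": "Hello",
    "max_tokens": 8,
    "temperature": 0.0,
}
for i in range(3):
    start = time.time()
    resp = requests.post("http://localhost:8000/v1/completions", json=payload, timeout=300)
    resp.raise_for_status()
    print(f"request {i}: {time.time() - start:.1f}s")
# On the June builds the first iteration takes 30-60 s; subsequent ones are fast.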

@UmakantKulkarni

Hi @tlrmchlsmth,

May I know when this PR is expected to be merged? I've also verified @kaln27's changes on an RTX 5090 (sm120), and they work well with the Llama 3.2 11B Vision model.

Thanks!

@waltstephen

Please check and merge the PR ASAP; this is very useful for people using 50-series and Blackwell GPUs.

@mgoin (Member) left a comment

Apologies for missing this PR, thanks for the kernel support! This looks reasonable to me, but could you share an e2e accuracy eval to make sure the kernel runs properly? Typically we use gsm8k on lm-eval

@kaln27 (Contributor, Author) commented Jul 1, 2025

Apologies for missing this PR, thanks for the kernel support! This looks reasonable to me, but could you share an e2e accuracy eval to make sure the kernel runs properly? Typically we use gsm8k on lm-eval

@mgoin thanks for your reply. I have downloaded [Qwen2.5-3B-FP8-dynamic](https://huggingface.co/RedHatAI/Qwen2.5-3B-FP8-dynamic) and run the gsm8k benchmark on it. The vLLM build that I used was built yesterday.

nvidia-smi

nvidia-smi 
Tue Jul  1 10:38:14 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5070 Ti     Off |   00000000:01:00.0 Off |                  N/A |
| 53%   32C    P0             22W /  300W |       0MiB /  16303MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Results

lm_eval \
  --model vllm \
  --model_args pretrained="/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic",dtype=auto,gpu_memory_utilization=0.9,add_bos_token=True,max_model_len=4096,enable_chunked_prefill=True,tensor_parallel_size=1 \
  --tasks gsm8k \
  --batch_size auto
INFO 07-01 10:26:54 [__init__.py:244] Automatically detected platform cuda.                                                                                    
2025-07-01:10:27:19 INFO     [__main__:440] Selected Tasks: ['gsm8k']                                                                                          
2025-07-01:10:27:19 INFO     [evaluator:185] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234                                                                                                                                                  
2025-07-01:10:27:19 INFO     [evaluator:223] Initializing vllm model, with arguments: {'pretrained': '/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic', 'dtype': '
auto', 'gpu_memory_utilization': 0.9, 'add_bos_token': True, 'max_model_len': 4096, 'enable_chunked_prefill': True, 'tensor_parallel_size': 1}
INFO 07-01 10:27:28 [config.py:831] This model supports multiple tasks: {'embed', 'generate', 'classify', 'score', 'reward'}. Defaulting to 'generate'.
INFO 07-01 10:27:28 [config.py:1444] Using max model len 4096
INFO 07-01 10:27:29 [config.py:2197] Chunked prefill is enabled with max_num_batched_tokens=8192.                                                              
INFO 07-01 10:27:30 [core.py:460] Waiting for init message from front-end.                                                                                     
INFO 07-01 10:27:30 [core.py:70] Initializing a V1 LLM engine (v0.1.dev7202+g7414eb0.d20250630) with config: model='/data/models/RedHatAI/Qwen2.5-3B-FP8-dynami
c', speculative_config=None, tokenizer='/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_
neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_paralle
l_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto,  device_config=
cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=1234, ser
ved_model_name=/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill
_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["no
ne"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_
auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,45
6,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,1
36,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}                                                                             
2025-07-01 10:27:30,718 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend                                                                 
WARNING 07-01 10:27:31 [utils.py:2756] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7416a73657e0>
INFO 07-01 10:27:32 [parallel_state.py:1072] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 07-01 10:27:32 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.                                                                     
INFO 07-01 10:27:32 [gpu_model_runner.py:1633] Starting to load model /data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic...                         
INFO 07-01 10:27:32 [gpu_model_runner.py:1638] Loading model from scratch...                                                                                   
INFO 07-01 10:27:33 [cuda.py:259] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]                                                                                   
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:29<00:00, 29.10s/it]                                                                           
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:29<00:00, 29.10s/it]                                                                           
                                                                                       
INFO 07-01 10:28:02 [default_loader.py:272] Loading weights took 29.20 seconds
INFO 07-01 10:28:02 [gpu_model_runner.py:1662] Model loading took 3.2290 GiB and 29.655233 seconds
INFO 07-01 10:28:17 [backends.py:508] Using cache directory: /data/liaojuncheng/.cache/vllm/torch_compile_cache/892cdfc123/rank_0_0/backbone for vLLM's torch.compile
INFO 07-01 10:28:17 [backends.py:519] Dynamo bytecode transform time: 14.21 s
INFO 07-01 10:28:20 [backends.py:181] Cache the graph of shape None for later use
INFO 07-01 10:28:51 [backends.py:193] Compiling a graph for general shape takes 33.59 s
INFO 07-01 10:29:06 [monitor.py:34] torch.compile takes 47.79 s in total
2025-07-01 10:29:06,431 - INFO - flashinfer.jit: Loading JIT ops: sampling
/data/liaojuncheng/miniconda3/envs/llm50xx/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
/data/liaojuncheng/miniconda3/envs/llm50xx/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
2025-07-01 10:29:06,672 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
INFO 07-01 10:29:07 [gpu_worker.py:232] Available KV cache memory: 9.87 GiB
INFO 07-01 10:29:07 [kv_cache_utils.py:716] GPU KV cache size: 287,376 tokens
INFO 07-01 10:29:07 [kv_cache_utils.py:720] Maximum concurrency for 4,096 tokens per request: 70.16x
WARNING 07-01 10:29:07 [utils.py:101] Unable to detect current VLLM config. Defaulting to NHD kv cache layout.
Capturing CUDA graphs: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [00:20<00:00,  3.35it/s]
INFO 07-01 10:29:27 [gpu_model_runner.py:2092] Graph capturing finished in 20 secs, took 0.72 GiB
INFO 07-01 10:29:27 [core.py:173] init engine (profile, create kv cache, warmup model) took 85.11 seconds
2025-07-01:10:29:41 INFO     [evaluator:286] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2025-07-01:10:29:41 INFO     [api.task:434] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:03<00:00, 350.59it/s]
2025-07-01:10:29:45 INFO     [evaluator:559] Running generate_until requests
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:00<00:00, 7570.54it/s]
Processed prompts: 100%|██████████████████████████████████████████| 1319/1319 [02:57<00:00,  7.43it/s, est. speed input: 7377.40 toks/s, output: 909.31 toks/s]
Running generate_until requests: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:57<00:00,  7.42it/s]
2025-07-01:10:32:49 INFO     [loggers.evaluation_tracker:272] Output path not provided, skipping saving results aggregated
vllm (pretrained=/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic,dtype=auto,gpu_memory_utilization=0.9,add_bos_token=True,max_model_len=4096,enable_chunked_prefill=True,tensor_parallel_size=1), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7324|±  |0.0122|
|     |       |strict-match    |     5|exact_match|↑  |0.6603|±  |0.0130|

You can see that the vLLM version is v0.1.dev7202+g7414eb0.d20250630, which I built yesterday. The quantization method is compressed-tensors.
The result is 0.6603 exact match (5-shot, strict-match). On the model card at https://huggingface.co/RedHatAI/Qwen2.5-3B-FP8-dynamic, the FP8-Dynamic model's reported result is 63.91, which means the kernel runs properly.
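
For completeness, the same evaluation can also be driven from Python through lm-eval's simple_evaluate API; a sketch assuming the same local checkpoint path as above (num_fewshot=5 is passed explicitly here, matching the 5-shot result reported by the CLI run):

# Programmatic variant of the lm_eval CLI run above (sketch).
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=/data/models/RedHatAI/Qwen2.5-3B-FP8-dynamic,"
        "dtype=auto,gpu_memory_utilization=0.9,add_bos_token=True,"
        "max_model_len=4096,enable_chunked_prefill=True,tensor_parallel_size=1"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])  # exact_match for strict-match and flexible-extract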

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 1, 2025
@mgoin mgoin changed the title [NVIDIA] Support Cutlass w8a8 for Blackwell Geforce GPUs (sm120) [NVIDIA] Support Cutlass w8a8 FP8 for Blackwell Geforce GPUs (sm120) Jul 1, 2025
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
Member
I think this comment is incorrect

Contributor Author
I just copied it from the one at the top. If you think it's incorrect, you can delete it.

@kaln27 (Contributor, Author) left a comment

I found that the other cutlass_scaled_mm targets in CMake (using CUTLASS 3.0) also have this comment.
LGTM


@mgoin mgoin merged commit 9e5552a into vllm-project:main Jul 2, 2025
113 checks passed
huydhn pushed a commit to huydhn/vllm that referenced this pull request Jul 8, 2025
LyrisZhong pushed a commit to LyrisZhong/vllm that referenced this pull request Jul 23, 2025
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Labels
ci/build · kernel · qwen (Related to Qwen models) · ready (ONLY add when PR is ready to merge/full CI is needed)

Development
Successfully merging this pull request may close these issues:
[Installation]: Dual 5090's (sm120, cu128) Issues Running vLLM

6 participants