
[Bug]: Running **Qwen2.5-VL-72B-Instruct-FP8-Dynamic** with vLLM 0.9.1/0.9.2-dev on an **RTX PRO 6000 Blackwell (96 GB)** throws RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false at torch.ops._C.cutlass_scaled_mm #20221

@waltstephen

Description

Your current environment

==============================
System Info

OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 4.0.3
Libc version : glibc-2.35

==============================
PyTorch Info

PyTorch version : 2.7.0+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A

==============================
Python Environment

Python version : 3.10.18 | packaged by conda-forge | (main, Jun 4 2025, 14:45:41) [GCC 13.3.0] (64-bit runtime)
Python platform : Linux-6.8.0-60-generic-x86_64-with-glibc2.35

==============================
CUDA / GPU Info

Is CUDA available : True
CUDA runtime version : 12.8.93
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration : GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version : 575.57.08
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True

==============================
CPU Info

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9V74 80-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 80
Socket(s): 1
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3701.9529
CPU min MHz: 1500.0000
BogoMIPS: 5192.04
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d sev sev_es debug_swap
Virtualization: AMD-V
L1d cache: 2.5 MiB (80 instances)
L1i cache: 2.5 MiB (80 instances)
L2 cache: 80 MiB (80 instances)
L3 cache: 320 MiB (10 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-159
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

==============================
Versions of relevant libraries

[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-cufile-cu12==1.13.0.11
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.0+cu128
[pip3] torchaudio==2.7.0+cu128
[pip3] torchvision==0.22.0+cu128
[pip3] transformers==4.52.4
[pip3] triton==3.3.0
[conda] cuda-cccl_linux-64 12.8.90 0 nvidia
[conda] cuda-command-line-tools 12.8.1 0 nvidia
[conda] cuda-compiler 12.8.1 0 nvidia
[conda] cuda-cudart 12.8.90 0 nvidia
[conda] cuda-cudart-dev 12.8.90 0 nvidia
[conda] cuda-cudart-dev_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart-static 12.8.90 0 nvidia
[conda] cuda-cudart-static_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart_linux-64 12.8.90 0 nvidia
[conda] cuda-cuobjdump 12.8.90 0 nvidia
[conda] cuda-cupti 12.8.90 0 nvidia
[conda] cuda-cupti-dev 12.8.90 0 nvidia
[conda] cuda-cuxxfilt 12.8.90 0 nvidia
[conda] cuda-driver-dev 12.8.90 0 nvidia
[conda] cuda-driver-dev_linux-64 12.8.90 0 nvidia
[conda] cuda-gdb 12.8.90 0 nvidia
[conda] cuda-libraries 12.8.1 0 nvidia
[conda] cuda-libraries-dev 12.8.1 0 nvidia
[conda] cuda-nsight 12.8.90 0 nvidia
[conda] cuda-nvcc 12.8.93 0 nvidia
[conda] cuda-nvcc_linux-64 12.8.93 0 nvidia
[conda] cuda-nvdisasm 12.8.90 0 nvidia
[conda] cuda-nvml-dev 12.8.90 0 nvidia
[conda] cuda-nvprof 12.8.90 0 nvidia
[conda] cuda-nvprune 12.8.90 0 nvidia
[conda] cuda-nvrtc 12.8.93 0 nvidia
[conda] cuda-nvrtc-dev 12.8.93 0 nvidia
[conda] cuda-nvtx 12.8.90 0 nvidia
[conda] cuda-nvtx-dev 12.8.90 0 nvidia
[conda] cuda-nvvp 12.8.93 0 nvidia
[conda] cuda-opencl 12.8.90 0 nvidia
[conda] cuda-opencl-dev 12.8.90 0 nvidia
[conda] cuda-profiler-api 12.8.90 0 nvidia
[conda] cuda-sanitizer-api 12.8.93 0 nvidia
[conda] cuda-toolkit 12.8.1 0 nvidia
[conda] cuda-tools 12.8.1 0 nvidia
[conda] cuda-version 12.8 3 nvidia
[conda] cuda-visual-tools 12.8.1 0 nvidia
[conda] gds-tools 1.13.1.3 0 nvidia
[conda] libcublas 12.8.4.1 0 nvidia
[conda] libcublas-dev 12.8.4.1 0 nvidia
[conda] libcufft 11.3.3.83 0 nvidia
[conda] libcufft-dev 11.3.3.83 0 nvidia
[conda] libcufile 1.13.1.3 0 nvidia
[conda] libcufile-dev 1.13.1.3 0 nvidia
[conda] libcurand 10.3.9.90 0 nvidia
[conda] libcurand-dev 10.3.9.90 0 nvidia
[conda] libcusolver 11.7.3.90 0 nvidia
[conda] libcusolver-dev 11.7.3.90 0 nvidia
[conda] libcusparse 12.5.8.93 0 nvidia
[conda] libcusparse-dev 12.5.8.93 0 nvidia
[conda] libnpp 12.3.3.100 0 nvidia
[conda] libnpp-dev 12.3.3.100 0 nvidia
[conda] libnvfatbin 12.8.90 0 nvidia
[conda] libnvfatbin-dev 12.8.90 0 nvidia
[conda] libnvjitlink 12.8.93 1 nvidia
[conda] libnvjitlink-dev 12.8.93 1 nvidia
[conda] libnvjpeg 12.3.5.92 0 nvidia
[conda] libnvjpeg-dev 12.3.5.92 0 nvidia
[conda] nsight-compute 2025.1.1.2 0 nvidia
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.3.14 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.7.1.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.41 pypi_0 pypi
[conda] nvidia-cufile-cu12 1.13.0.11 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.55 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.2.55 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.7.53 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.55 pypi_0 pypi
[conda] pyzmq 27.0.0 pypi_0 pypi
[conda] torch 2.7.0+cu128 pypi_0 pypi
[conda] torchaudio 2.7.0+cu128 pypi_0 pypi
[conda] torchvision 0.22.0+cu128 pypi_0 pypi
[conda] transformers 4.52.4 pypi_0 pypi
[conda] triton 3.3.0 pypi_0 pypi

==============================
vLLM Info

ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.9.2.dev311+g4d3669368 (git sha: 4d36693)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X 0-159 0 N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

==============================
Environment Variables

NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

The engine crashes during the first dummy forward pass that profiles the KV-cache size. The error bubbles up from `w8a8_scaled_mm_func` → `cutlass_scaled_mm` inside the compressed-tensors FP8 path and prevents the engine core from starting.
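For triage, the kernel support check can be run without loading the 72B checkpoint. Below is a minimal diagnostic sketch, assuming this vLLM build still exposes `cutlass_scaled_mm_supports_fp8` in `vllm._custom_ops` and that the RTX PRO 6000 Blackwell reports compute capability 12.0:

```python
# Diagnostic sketch (assumption: cutlass_scaled_mm_supports_fp8 exists in this
# vLLM build). Checks whether the compiled CUTLASS kernels advertise FP8
# support for the locally detected GPU architecture.
import torch
from vllm import _custom_ops as ops

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # e.g. 120 for an SM 12.0 Blackwell workstation card
print(f"Compute capability: {major}.{minor}")
print("CUTLASS FP8 support reported:", ops.cutlass_scaled_mm_supports_fp8(capability))
```

If this prints False, the FP8 linear path presumably should not dispatch to `cutlass_scaled_mm` with FP8 inputs on this card at all, which would suggest the support check and the actual dispatch disagree for this architecture.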

✔️ Reproduction

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="./Qwen2.5-VL-72B-Instruct-FP8-Dynamic",
    trust_remote_code=True,
    max_model_len=4096,      # same as config.json
    enforce_eager=True,      # ✱ enforcing eager just to rule out torch.compile
)

question = "What’s in this picture?"
inputs = {
    # Qwen2.5-VL chat format with the image placeholder tokens vLLM expects
    "prompt": ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
               f"{question}<|im_end|>\n<|im_start|>assistant\n"),
    "multi_modal_data": {
        "image": ImageAsset("cherry_blossom").pil_image.convert("RGB")
    },
}

llm.generate(inputs, SamplingParams(temperature=0.2, max_tokens=64))

### Error

```bash

INFO 06-29 08:28:45 [__init__.py:244] Automatically detected platform cuda.
INFO 06-29 08:28:53 [config.py:853] This model supports multiple tasks: {'embed', 'generate', 'score', 'classify', 'reward'}. Defaulting to 'generate'.
INFO 06-29 08:28:53 [config.py:1467] Using max model len 4096
INFO 06-29 08:28:54 [config.py:2267] Chunked prefill is enabled with max_num_batched_tokens=16384.
WARNING 06-29 08:28:54 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
WARNING 06-29 08:28:54 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
WARNING 06-29 08:28:54 [cuda.py:102] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 06-29 08:28:54 [core.py:459] Waiting for init message from front-end.
INFO 06-29 08:28:54 [core.py:69] Initializing a V1 LLM engine (v0.9.2.dev311+g4d3669368) with config: model='./Qwen2.5-VL-72B-Instruct-FP8-Dynamic', speculative_config=None, tokenizer='./Qwen2.5-VL-72B-Instruct-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=./Qwen2.5-VL-72B-Instruct-FP8-Dynamic, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":false,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
WARNING 06-29 08:28:55 [utils.py:2753] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7740b7e2ab90>
INFO 06-29 08:28:55 [parallel_state.py:1076] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 06-29 08:28:56 [profiling.py:266] The sequence length (4096) is smaller than the pre-defined wosrt-case total number of multimodal tokens (32768). This may cause certain multi-modal inputs to fail during inference. To avoid this, you should increase `max_model_len` or reduce `mm_counts`.
WARNING 06-29 08:28:56 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 06-29 08:28:56 [gpu_model_runner.py:1725] Starting to load model ./Qwen2.5-VL-72B-Instruct-FP8-Dynamic...
INFO 06-29 08:28:56 [gpu_model_runner.py:1730] Loading model from scratch...
WARNING 06-29 08:28:56 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
WARNING 06-29 08:28:56 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
INFO 06-29 08:28:56 [cuda.py:270] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards:   0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   6% Completed | 1/16 [00:00<00:09,  1.56it/s]
Loading safetensors checkpoint shards:  12% Completed | 2/16 [00:01<00:10,  1.34it/s]
Loading safetensors checkpoint shards:  19% Completed | 3/16 [00:02<00:09,  1.33it/s]
Loading safetensors checkpoint shards:  25% Completed | 4/16 [00:02<00:08,  1.34it/s]
Loading safetensors checkpoint shards:  31% Completed | 5/16 [00:03<00:08,  1.36it/s]
Loading safetensors checkpoint shards:  38% Completed | 6/16 [00:04<00:07,  1.38it/s]
Loading safetensors checkpoint shards:  44% Completed | 7/16 [00:05<00:06,  1.37it/s]
Loading safetensors checkpoint shards:  50% Completed | 8/16 [00:05<00:05,  1.36it/s]
Loading safetensors checkpoint shards:  56% Completed | 9/16 [00:06<00:05,  1.39it/s]
Loading safetensors checkpoint shards:  62% Completed | 10/16 [00:07<00:04,  1.39it/s]
Loading safetensors checkpoint shards:  69% Completed | 11/16 [00:07<00:03,  1.40it/s]
Loading safetensors checkpoint shards:  75% Completed | 12/16 [00:08<00:02,  1.41it/s]
Loading safetensors checkpoint shards:  81% Completed | 13/16 [00:09<00:02,  1.42it/s]
Loading safetensors checkpoint shards:  88% Completed | 14/16 [00:10<00:01,  1.40it/s]
Loading safetensors checkpoint shards:  94% Completed | 15/16 [00:10<00:00,  1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:11<00:00,  1.48it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:11<00:00,  1.40it/s]

INFO 06-29 08:29:08 [default_loader.py:272] Loading weights took 11.51 seconds
INFO 06-29 08:29:08 [gpu_model_runner.py:1756] Model loading took 71.5736 GiB and 11.800011 seconds
INFO 06-29 08:29:08 [gpu_model_runner.py:2195] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
ERROR 06-29 08:29:12 [core.py:519] EngineCore failed to start.
ERROR 06-29 08:29:12 [core.py:519] Traceback (most recent call last):
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
ERROR 06-29 08:29:12 [core.py:519]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 394, in __init__
ERROR 06-29 08:29:12 [core.py:519]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 82, in __init__
ERROR 06-29 08:29:12 [core.py:519]     self._initialize_kv_caches(vllm_config)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
ERROR 06-29 08:29:12 [core.py:519]     available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
ERROR 06-29 08:29:12 [core.py:519]     output = self.collective_rpc("determine_available_memory")
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-29 08:29:12 [core.py:519]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/utils.py", line 2687, in run_method
ERROR 06-29 08:29:12 [core.py:519]     return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-29 08:29:12 [core.py:519]     return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
ERROR 06-29 08:29:12 [core.py:519]     self.model_runner.profile_run()
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2231, in profile_run
ERROR 06-29 08:29:12 [core.py:519]     = self._dummy_run(self.max_num_tokens, is_profile=True)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-29 08:29:12 [core.py:519]     return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2012, in _dummy_run
ERROR 06-29 08:29:12 [core.py:519]     outputs = model(
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519]     return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519]     return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1136, in forward
ERROR 06-29 08:29:12 [core.py:519]     hidden_states = self.language_model.model(
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/compilation/decorators.py", line 173, in __call__
ERROR 06-29 08:29:12 [core.py:519]     return self.forward(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 354, in forward
ERROR 06-29 08:29:12 [core.py:519]     hidden_states, residual = layer(
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519]     return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519]     return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 253, in forward
ERROR 06-29 08:29:12 [core.py:519]     hidden_states = self.self_attn(
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519]     return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519]     return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 180, in forward
ERROR 06-29 08:29:12 [core.py:519]     qkv, _ = self.qkv_proj(hidden_states)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519]     return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519]     return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/linear.py", line 487, in forward
ERROR 06-29 08:29:12 [core.py:519]     output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 633, in apply
ERROR 06-29 08:29:12 [core.py:519]     return scheme.apply_weights(layer, x, bias=bias)
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 145, in apply_weights
ERROR 06-29 08:29:12 [core.py:519]     return self.fp8_linear.apply(input=x,
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 374, in apply
ERROR 06-29 08:29:12 [core.py:519]     return w8a8_scaled_mm_func(qinput=qinput,
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 143, in cutlass_w8a8_scaled_mm
ERROR 06-29 08:29:12 [core.py:519]     output = ops.cutlass_scaled_mm(qinput,
ERROR 06-29 08:29:12 [core.py:519]   File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/_custom_ops.py", line 713, in cutlass_scaled_mm
ERROR 06-29 08:29:12 [core.py:519]     torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
ERROR 06-29 08:29:12 [core.py:519]   File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 06-29 08:29:12 [core.py:519]     return self._op(*args, **(kwargs or {}))
ERROR 06-29 08:29:12 [core.py:519] RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Process EngineCore_0:
Traceback (most recent call last):
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 523, in run_engine_core
    raise e
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 394, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 82, in __init__
    self._initialize_kv_caches(vllm_config)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
    output = self.collective_rpc("determine_available_memory")
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/utils.py", line 2687, in run_method
    return func(*args, **kwargs)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
    self.model_runner.profile_run()
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2231, in profile_run
    = self._dummy_run(self.max_num_tokens, is_profile=True)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2012, in _dummy_run
    outputs = model(
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1136, in forward
    hidden_states = self.language_model.model(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/compilation/decorators.py", line 173, in __call__
    return self.forward(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 354, in forward
    hidden_states, residual = layer(
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 253, in forward
    hidden_states = self.self_attn(
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 180, in forward
    qkv, _ = self.qkv_proj(hidden_states)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/linear.py", line 487, in forward
    output_parallel = self.quant_method.apply(self, input_, bias)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 633, in apply
    return scheme.apply_weights(layer, x, bias=bias)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 145, in apply_weights
    return self.fp8_linear.apply(input=x,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 374, in apply
    return w8a8_scaled_mm_func(qinput=qinput,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 143, in cutlass_w8a8_scaled_mm
    output = ops.cutlass_scaled_mm(qinput,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/_custom_ops.py", line 713, in cutlass_scaled_mm
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
  File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Traceback (most recent call last):
  File "/home/jusheng/jusheng_files/Qwen2.5VL_FP_infer/test.py", line 14, in <module>
    llm = LLM(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/entrypoints/llm.py", line 263, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/engine/llm_engine.py", line 501, in from_engine_args
    return engine_cls.from_vllm_config(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/llm_engine.py", line 124, in from_vllm_config
    return cls(vllm_config=vllm_config,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/llm_engine.py", line 101, in __init__
    self.engine_core = EngineCoreClient.make_client(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 75, in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 572, in __init__
    super().__init__(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 433, in __init__
    self._init_engines_direct(vllm_config, local_only,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 502, in _init_engines_direct
    self._wait_for_engine_startup(handshake_socket, input_address,
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 522, in _wait_for_engine_startup
    wait_for_engine_startup(
  File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/utils.py", line 494, in wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
```

The core issue is **RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false** raised by torch.ops._C.cutlass_scaled_mm.

I am now trying an INT8 quantization of the model to work around this FP8 issue.
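Before switching to INT8, it may be worth confirming whether the FP8 CUTLASS kernel works at all on this card outside the engine. A minimal, hedged sketch follows; the keyword names (`scale_a`, `scale_b`, `out_dtype`, `bias`) mirror the call site shown in the traceback (`w8a8_utils.cutlass_w8a8_scaled_mm`) and are an assumption about this vLLM version, as are the per-tensor scale shapes:

```python
# Standalone repro sketch of the failing custom op with FP8 inputs
# (assumption: ops.cutlass_scaled_mm keyword names match the call site in the
# traceback above; per-tensor float32 scales are used for simplicity).
import torch
from vllm import _custom_ops as ops

M, K, N = 16, 4096, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major activations
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major weights, shape (K, N)
scale_a = torch.ones(1, 1, device="cuda", dtype=torch.float32)    # per-tensor activation scale
scale_b = torch.ones(1, 1, device="cuda", dtype=torch.float32)    # per-tensor weight scale

out = ops.cutlass_scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16, bias=None)
print(out.shape, out.dtype)  # expected: torch.Size([16, 4096]) torch.bfloat16
```

If this small call fails with the same `Expected a.dtype() == torch::kInt8` message, the FP8 CUTLASS path is most likely not available for SM 12.0 in this build and vLLM should be falling back to the non-CUTLASS route; if it succeeds, the problem more likely sits in how the compressed-tensors scheme prepares its inputs.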


### Before submitting a new issue...

- [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
