Description
Your current environment
==============================
System Info
OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 4.0.3
Libc version : glibc-2.35
==============================
PyTorch Info
PyTorch version : 2.7.0+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A
==============================
Python Environment
Python version : 3.10.18 | packaged by conda-forge | (main, Jun 4 2025, 14:45:41) [GCC 13.3.0] (64-bit runtime)
Python platform : Linux-6.8.0-60-generic-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
Is CUDA available : True
CUDA runtime version : 12.8.93
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration : GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version : 575.57.08
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9V74 80-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 80
Socket(s): 1
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3701.9529
CPU min MHz: 1500.0000
BogoMIPS: 5192.04
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d sev sev_es debug_swap
Virtualization: AMD-V
L1d cache: 2.5 MiB (80 instances)
L1i cache: 2.5 MiB (80 instances)
L2 cache: 80 MiB (80 instances)
L3 cache: 320 MiB (10 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-159
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-cufile-cu12==1.13.0.11
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.0+cu128
[pip3] torchaudio==2.7.0+cu128
[pip3] torchvision==0.22.0+cu128
[pip3] transformers==4.52.4
[pip3] triton==3.3.0
[conda] cuda-cccl_linux-64 12.8.90 0 nvidia
[conda] cuda-command-line-tools 12.8.1 0 nvidia
[conda] cuda-compiler 12.8.1 0 nvidia
[conda] cuda-cudart 12.8.90 0 nvidia
[conda] cuda-cudart-dev 12.8.90 0 nvidia
[conda] cuda-cudart-dev_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart-static 12.8.90 0 nvidia
[conda] cuda-cudart-static_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart_linux-64 12.8.90 0 nvidia
[conda] cuda-cuobjdump 12.8.90 0 nvidia
[conda] cuda-cupti 12.8.90 0 nvidia
[conda] cuda-cupti-dev 12.8.90 0 nvidia
[conda] cuda-cuxxfilt 12.8.90 0 nvidia
[conda] cuda-driver-dev 12.8.90 0 nvidia
[conda] cuda-driver-dev_linux-64 12.8.90 0 nvidia
[conda] cuda-gdb 12.8.90 0 nvidia
[conda] cuda-libraries 12.8.1 0 nvidia
[conda] cuda-libraries-dev 12.8.1 0 nvidia
[conda] cuda-nsight 12.8.90 0 nvidia
[conda] cuda-nvcc 12.8.93 0 nvidia
[conda] cuda-nvcc_linux-64 12.8.93 0 nvidia
[conda] cuda-nvdisasm 12.8.90 0 nvidia
[conda] cuda-nvml-dev 12.8.90 0 nvidia
[conda] cuda-nvprof 12.8.90 0 nvidia
[conda] cuda-nvprune 12.8.90 0 nvidia
[conda] cuda-nvrtc 12.8.93 0 nvidia
[conda] cuda-nvrtc-dev 12.8.93 0 nvidia
[conda] cuda-nvtx 12.8.90 0 nvidia
[conda] cuda-nvtx-dev 12.8.90 0 nvidia
[conda] cuda-nvvp 12.8.93 0 nvidia
[conda] cuda-opencl 12.8.90 0 nvidia
[conda] cuda-opencl-dev 12.8.90 0 nvidia
[conda] cuda-profiler-api 12.8.90 0 nvidia
[conda] cuda-sanitizer-api 12.8.93 0 nvidia
[conda] cuda-toolkit 12.8.1 0 nvidia
[conda] cuda-tools 12.8.1 0 nvidia
[conda] cuda-version 12.8 3 nvidia
[conda] cuda-visual-tools 12.8.1 0 nvidia
[conda] gds-tools 1.13.1.3 0 nvidia
[conda] libcublas 12.8.4.1 0 nvidia
[conda] libcublas-dev 12.8.4.1 0 nvidia
[conda] libcufft 11.3.3.83 0 nvidia
[conda] libcufft-dev 11.3.3.83 0 nvidia
[conda] libcufile 1.13.1.3 0 nvidia
[conda] libcufile-dev 1.13.1.3 0 nvidia
[conda] libcurand 10.3.9.90 0 nvidia
[conda] libcurand-dev 10.3.9.90 0 nvidia
[conda] libcusolver 11.7.3.90 0 nvidia
[conda] libcusolver-dev 11.7.3.90 0 nvidia
[conda] libcusparse 12.5.8.93 0 nvidia
[conda] libcusparse-dev 12.5.8.93 0 nvidia
[conda] libnpp 12.3.3.100 0 nvidia
[conda] libnpp-dev 12.3.3.100 0 nvidia
[conda] libnvfatbin 12.8.90 0 nvidia
[conda] libnvfatbin-dev 12.8.90 0 nvidia
[conda] libnvjitlink 12.8.93 1 nvidia
[conda] libnvjitlink-dev 12.8.93 1 nvidia
[conda] libnvjpeg 12.3.5.92 0 nvidia
[conda] libnvjpeg-dev 12.3.5.92 0 nvidia
[conda] nsight-compute 2025.1.1.2 0 nvidia
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.3.14 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.57 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.7.1.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.41 pypi_0 pypi
[conda] nvidia-cufile-cu12 1.13.0.11 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.55 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.2.55 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.7.53 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.3 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.26.2 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.61 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.55 pypi_0 pypi
[conda] pyzmq 27.0.0 pypi_0 pypi
[conda] torch 2.7.0+cu128 pypi_0 pypi
[conda] torchaudio 2.7.0+cu128 pypi_0 pypi
[conda] torchvision 0.22.0+cu128 pypi_0 pypi
[conda] transformers 4.52.4 pypi_0 pypi
[conda] triton 3.3.0 pypi_0 pypi
==============================
vLLM Info
ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.9.2.dev311+g4d3669368 (git sha: 4d36693)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X 0-159 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
==============================
Environment Variables
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
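Loading `Qwen2.5-VL-72B-Instruct-FP8-Dynamic` fails with `RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false` during the first dummy forward pass that profiles KV-cache size.
The error bubbles up from `w8a8_scaled_mm_func` → `cutlass_scaled_mm` inside the compressed-tensors FP8 path and prevents the engine core from starting.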
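For triage, it may help to attach a few runtime facts next to the error. The sketch below is only a diagnostic aid, not a fix: it reports the PyTorch version, the GPU name, the CUDA compute capability, and whether this PyTorch build exposes an FP8 dtype. That the compute capability influences which scaled-matmul kernel gets selected is an assumption here, not something the trace confirms. The helper name `gpu_dtype_report` is made up for illustration.

```python
def gpu_dtype_report():
    """Collect facts that may matter for FP8 kernel selection.

    Safe to run on a machine without PyTorch or without a GPU:
    missing pieces are simply reported as None / False.
    """
    report = {
        "torch": None,
        "cuda_available": False,
        "device_name": None,
        "compute_capability": None,
        "has_float8_e4m3fn": False,
    }
    try:
        import torch
    except ImportError:
        return report

    report["torch"] = torch.__version__
    # torch.float8_e4m3fn is the FP8 dtype exposed by recent PyTorch builds.
    report["has_float8_e4m3fn"] = hasattr(torch, "float8_e4m3fn")
    if torch.cuda.is_available():
        report["cuda_available"] = True
        report["device_name"] = torch.cuda.get_device_name(0)
        # Returned as a (major, minor) tuple.
        report["compute_capability"] = torch.cuda.get_device_capability(0)
    return report


report = gpu_dtype_report()
for key, value in report.items():
    print(f"{key}: {value}")
```

Pasting this output into the issue alongside the `CUDA Archs: Not Set` build flag from the environment dump gives maintainers both the build-time and run-time view of the target architecture.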
✔️ Reproduction
```python
import os

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# enforcing eager just to rule out torch.compile
os.environ["VLLM_ENFORCE_EAGER"] = "1"

llm = LLM(
    model="./Qwen2.5-VL-72B-Instruct-FP8-Dynamic",
    trust_remote_code=True,
    max_model_len=4096,  # same as config.json
)

question = "What’s in this picture?"
inputs = {
    "prompt": f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n",
    "multi_modal_data": {
        "image": ImageAsset("cherry_blossom").pil_image.convert("RGB")
    },
}
llm.generate(inputs, SamplingParams(temperature=0.2, max_tokens=64))
```
### Error
```bash
INFO 06-29 08:28:45 [__init__.py:244] Automatically detected platform cuda.
INFO 06-29 08:28:53 [config.py:853] This model supports multiple tasks: {'embed', 'generate', 'score', 'classify', 'reward'}. Defaulting to 'generate'.
INFO 06-29 08:28:53 [config.py:1467] Using max model len 4096
INFO 06-29 08:28:54 [config.py:2267] Chunked prefill is enabled with max_num_batched_tokens=16384.
WARNING 06-29 08:28:54 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
WARNING 06-29 08:28:54 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
WARNING 06-29 08:28:54 [cuda.py:102] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 06-29 08:28:54 [core.py:459] Waiting for init message from front-end.
INFO 06-29 08:28:54 [core.py:69] Initializing a V1 LLM engine (v0.9.2.dev311+g4d3669368) with config: model='./Qwen2.5-VL-72B-Instruct-FP8-Dynamic', speculative_config=None, tokenizer='./Qwen2.5-VL-72B-Instruct-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=./Qwen2.5-VL-72B-Instruct-FP8-Dynamic, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":false,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
WARNING 06-29 08:28:55 [utils.py:2753] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7740b7e2ab90>
INFO 06-29 08:28:55 [parallel_state.py:1076] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 06-29 08:28:56 [profiling.py:266] The sequence length (4096) is smaller than the pre-defined wosrt-case total number of multimodal tokens (32768). This may cause certain multi-modal inputs to fail during inference. To avoid this, you should increase `max_model_len` or reduce `mm_counts`.
WARNING 06-29 08:28:56 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 06-29 08:28:56 [gpu_model_runner.py:1725] Starting to load model ./Qwen2.5-VL-72B-Instruct-FP8-Dynamic...
INFO 06-29 08:28:56 [gpu_model_runner.py:1730] Loading model from scratch...
WARNING 06-29 08:28:56 [vision.py:91] Current `vllm-flash-attn` has a bug inside vision module, so we use xformers backend instead. You can run `pip install flash-attn` to use flash-attention backend.
WARNING 06-29 08:28:56 [config.py:2303] max_num_batched_tokens (16384) exceeds max_num_seqs* max_model_len (8192). This may lead to unexpected behavior.
INFO 06-29 08:28:56 [cuda.py:270] Using Flash Attention backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/16 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 6% Completed | 1/16 [00:00<00:09, 1.56it/s]
Loading safetensors checkpoint shards: 12% Completed | 2/16 [00:01<00:10, 1.34it/s]
Loading safetensors checkpoint shards: 19% Completed | 3/16 [00:02<00:09, 1.33it/s]
Loading safetensors checkpoint shards: 25% Completed | 4/16 [00:02<00:08, 1.34it/s]
Loading safetensors checkpoint shards: 31% Completed | 5/16 [00:03<00:08, 1.36it/s]
Loading safetensors checkpoint shards: 38% Completed | 6/16 [00:04<00:07, 1.38it/s]
Loading safetensors checkpoint shards: 44% Completed | 7/16 [00:05<00:06, 1.37it/s]
Loading safetensors checkpoint shards: 50% Completed | 8/16 [00:05<00:05, 1.36it/s]
Loading safetensors checkpoint shards: 56% Completed | 9/16 [00:06<00:05, 1.39it/s]
Loading safetensors checkpoint shards: 62% Completed | 10/16 [00:07<00:04, 1.39it/s]
Loading safetensors checkpoint shards: 69% Completed | 11/16 [00:07<00:03, 1.40it/s]
Loading safetensors checkpoint shards: 75% Completed | 12/16 [00:08<00:02, 1.41it/s]
Loading safetensors checkpoint shards: 81% Completed | 13/16 [00:09<00:02, 1.42it/s]
Loading safetensors checkpoint shards: 88% Completed | 14/16 [00:10<00:01, 1.40it/s]
Loading safetensors checkpoint shards: 94% Completed | 15/16 [00:10<00:00, 1.41it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:11<00:00, 1.48it/s]
Loading safetensors checkpoint shards: 100% Completed | 16/16 [00:11<00:00, 1.40it/s]
INFO 06-29 08:29:08 [default_loader.py:272] Loading weights took 11.51 seconds
INFO 06-29 08:29:08 [gpu_model_runner.py:1756] Model loading took 71.5736 GiB and 11.800011 seconds
INFO 06-29 08:29:08 [gpu_model_runner.py:2195] Encoder cache will be initialized with a budget of 16384 tokens, and profiled with 1 image items of the maximum feature size.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
ERROR 06-29 08:29:12 [core.py:519] EngineCore failed to start.
ERROR 06-29 08:29:12 [core.py:519] Traceback (most recent call last):
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
ERROR 06-29 08:29:12 [core.py:519] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 394, in __init__
ERROR 06-29 08:29:12 [core.py:519] super().__init__(vllm_config, executor_class, log_stats,
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 82, in __init__
ERROR 06-29 08:29:12 [core.py:519] self._initialize_kv_caches(vllm_config)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
ERROR 06-29 08:29:12 [core.py:519] available_gpu_memory = self.model_executor.determine_available_memory()
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
ERROR 06-29 08:29:12 [core.py:519] output = self.collective_rpc("determine_available_memory")
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 06-29 08:29:12 [core.py:519] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/utils.py", line 2687, in run_method
ERROR 06-29 08:29:12 [core.py:519] return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-29 08:29:12 [core.py:519] return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
ERROR 06-29 08:29:12 [core.py:519] self.model_runner.profile_run()
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2231, in profile_run
ERROR 06-29 08:29:12 [core.py:519] = self._dummy_run(self.max_num_tokens, is_profile=True)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-29 08:29:12 [core.py:519] return func(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2012, in _dummy_run
ERROR 06-29 08:29:12 [core.py:519] outputs = model(
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519] return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519] return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1136, in forward
ERROR 06-29 08:29:12 [core.py:519] hidden_states = self.language_model.model(
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/compilation/decorators.py", line 173, in __call__
ERROR 06-29 08:29:12 [core.py:519] return self.forward(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 354, in forward
ERROR 06-29 08:29:12 [core.py:519] hidden_states, residual = layer(
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519] return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519] return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 253, in forward
ERROR 06-29 08:29:12 [core.py:519] hidden_states = self.self_attn(
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519] return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519] return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 180, in forward
ERROR 06-29 08:29:12 [core.py:519] qkv, _ = self.qkv_proj(hidden_states)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-29 08:29:12 [core.py:519] return self._call_impl(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-29 08:29:12 [core.py:519] return forward_call(*args, **kwargs)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/linear.py", line 487, in forward
ERROR 06-29 08:29:12 [core.py:519] output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 633, in apply
ERROR 06-29 08:29:12 [core.py:519] return scheme.apply_weights(layer, x, bias=bias)
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 145, in apply_weights
ERROR 06-29 08:29:12 [core.py:519] return self.fp8_linear.apply(input=x,
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 374, in apply
ERROR 06-29 08:29:12 [core.py:519] return w8a8_scaled_mm_func(qinput=qinput,
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 143, in cutlass_w8a8_scaled_mm
ERROR 06-29 08:29:12 [core.py:519] output = ops.cutlass_scaled_mm(qinput,
ERROR 06-29 08:29:12 [core.py:519] File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/_custom_ops.py", line 713, in cutlass_scaled_mm
ERROR 06-29 08:29:12 [core.py:519] torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
ERROR 06-29 08:29:12 [core.py:519] File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 06-29 08:29:12 [core.py:519] return self._op(*args, **(kwargs or {}))
ERROR 06-29 08:29:12 [core.py:519] RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Process EngineCore_0:
Traceback (most recent call last):
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 523, in run_engine_core
raise e
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 510, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 394, in __init__
super().__init__(vllm_config, executor_class, log_stats,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 82, in __init__
self._initialize_kv_caches(vllm_config)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core.py", line 142, in _initialize_kv_caches
available_gpu_memory = self.model_executor.determine_available_memory()
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/executor/abstract.py", line 76, in determine_available_memory
output = self.collective_rpc("determine_available_memory")
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/utils.py", line 2687, in run_method
return func(*args, **kwargs)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_worker.py", line 210, in determine_available_memory
self.model_runner.profile_run()
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2231, in profile_run
= self._dummy_run(self.max_num_tokens, is_profile=True)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/worker/gpu_model_runner.py", line 2012, in _dummy_run
outputs = model(
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2_5_vl.py", line 1136, in forward
hidden_states = self.language_model.model(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/compilation/decorators.py", line 173, in __call__
return self.forward(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 354, in forward
hidden_states, residual = layer(
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 253, in forward
hidden_states = self.self_attn(
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/models/qwen2.py", line 180, in forward
qkv, _ = self.qkv_proj(hidden_states)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/linear.py", line 487, in forward
output_parallel = self.quant_method.apply(self, input_, bias)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py", line 633, in apply
return scheme.apply_weights(layer, x, bias=bias)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 145, in apply_weights
return self.fp8_linear.apply(input=x,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 374, in apply
return w8a8_scaled_mm_func(qinput=qinput,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 143, in cutlass_w8a8_scaled_mm
output = ops.cutlass_scaled_mm(qinput,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/_custom_ops.py", line 713, in cutlass_scaled_mm
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
File "/home/jusheng/miniconda3/envs/qwenfp8/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Traceback (most recent call last):
File "/home/jusheng/jusheng_files/Qwen2.5VL_FP_infer/test.py", line 14, in <module>
llm = LLM(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/entrypoints/llm.py", line 263, in __init__
self.llm_engine = LLMEngine.from_engine_args(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/engine/llm_engine.py", line 501, in from_engine_args
return engine_cls.from_vllm_config(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/llm_engine.py", line 124, in from_vllm_config
return cls(vllm_config=vllm_config,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/llm_engine.py", line 101, in __init__
self.engine_core = EngineCoreClient.make_client(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 75, in make_client
return SyncMPClient(vllm_config, executor_class, log_stats)
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 572, in __init__
super().__init__(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 433, in __init__
self._init_engines_direct(vllm_config, local_only,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 502, in _init_engines_direct
self._wait_for_engine_startup(handshake_socket, input_address,
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/engine/core_client.py", line 522, in _wait_for_engine_startup
wait_for_engine_startup(
File "/mnt/SSD_8T/jusheng/jusheng_files/vllm/vllm/v1/utils.py", line 494, in wait_for_engine_startup
raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
```
The core issue is **RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)**.
I am now trying INT8 instead to work around this FP8 issue.
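Because the log prints the same traceback twice (once with the `ERROR ... [core.py:519]` prefix and once from the `EngineCore_0` process), the root cause is easy to miss. A minimal sketch for pulling the distinct exception lines out of such a log follows; the function name `root_cause_lines` and the prefix pattern are assumptions based on the log format shown above, not part of vLLM.

```python
import re

# Matches the log prefix seen in the trace above, e.g.
# "ERROR 06-29 08:29:12 [core.py:519] ".
LOG_PREFIX = re.compile(
    r"^(?:INFO|WARNING|ERROR) \d{2}-\d{2} \d{2}:\d{2}:\d{2} \[[^\]]+\] ?"
)
# Matches lines such as "RuntimeError: ..." or "ValueError: ...".
EXCEPTION_LINE = re.compile(r"^[A-Za-z_][\w.]*(?:Error|Exception): ")


def root_cause_lines(log_text):
    """Return distinct exception lines in order of first appearance."""
    seen = []
    for raw in log_text.splitlines():
        line = LOG_PREFIX.sub("", raw).strip()
        if EXCEPTION_LINE.match(line) and line not in seen:
            seen.append(line)
    return seen


sample = """\
ERROR 06-29 08:29:12 [core.py:519] return self._op(*args, **(kwargs or {}))
ERROR 06-29 08:29:12 [core.py:519] RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false.
RuntimeError: Expected a.dtype() == torch::kInt8 to be true, but got false.
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
"""

causes = root_cause_lines(sample)
for cause in causes:
    print(cause)
```

On the sample, the first reported line is the dtype check failure and the second is the generic engine start-up failure that wraps it.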
### Before submitting a new issue...
- [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.