Checklist
- 1. I have searched related issues but cannot get the expected help.
- 2. The bug has not been fixed in the latest version.
- 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
When processing multimodal chat messages that contain multiple text blocks, the current implementation extracts only the first text block (`msg['content'][0]['text']`) and discards all subsequent ones. This leads to incomplete prompt processing and loss of user input.
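The loss is easy to demonstrate with plain Python, independent of LMDeploy: indexing only the first block drops everything after it, while joining the blocks preserves the full input (a minimal illustration; the dict layout mirrors the OpenAI multimodal content format):

```python
# Minimal illustration of the dropped-block behavior (no LMDeploy required).
# The content layout mirrors the OpenAI multimodal message format.
content = [
    {'type': 'text', 'text': 'This is a test.'},
    {'type': 'text', 'text': 'Say 1..2..3'},
]

# Current behavior: only the first block survives.
buggy = content[0]['text']

# Expected behavior: every text block merged with newlines.
merged = '\n'.join(b['text'] for b in content if b['type'] == 'text')

print(buggy)   # This is a test.
print(merged)  # This is a test.
               # Say 1..2..3
```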
Reproduction
This issue was discovered when Claude Code sends messages with multiple text blocks in a single user message.
- Start an LMDeploy server with Qwen3-32B:

```shell
lmdeploy serve api_server Qwen/Qwen3-32B --tool-call-parser qwen3
```
- Send a user message that contains multiple separate text blocks, as Claude Code does:
```python
from openai import OpenAI

client = OpenAI(
    api_key='EMPTY',
    base_url='http://localhost:23333/v1'
)
model = 'Qwen/Qwen3-32B'

# This mimics what Claude Code sends when the user provides complex instructions
messages = [{
    'role': 'user',
    'content': [
        {
            'type': 'text',
            'text': 'This is a test.'
        },
        {
            'type': 'text',
            'text': 'Say 1..2..3'
        },
    ]
}]

response = client.chat.completions.create(
    model=model,
    messages=messages
)
```
Root Cause
In `lmdeploy/serve/async_engine.py`, the code only extracts the first element:

```python
dict(role=msg['role'], content=msg['content'][0]['text'])
```

It should instead iterate over all text blocks and merge them with `'\n'.join()`.
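A possible fix could look like the sketch below. This is illustrative only: `flatten_message` is a hypothetical helper name, and the actual code path in `async_engine.py` may differ.

```python
def flatten_message(msg: dict) -> dict:
    """Merge all text blocks of a chat message into a single string.

    Hypothetical helper; the real fix would live in
    lmdeploy/serve/async_engine.py and may be shaped differently.
    """
    content = msg['content']
    if isinstance(content, str):
        # Plain-string content needs no flattening.
        return dict(role=msg['role'], content=content)
    # Collect every text block instead of just content[0]['text'].
    texts = [item['text'] for item in content if item.get('type') == 'text']
    return dict(role=msg['role'], content='\n'.join(texts))

msg = {'role': 'user', 'content': [
    {'type': 'text', 'text': 'This is a test.'},
    {'type': 'text', 'text': 'Say 1..2..3'},
]}
print(flatten_message(msg)['content'])  # This is a test.
                                        # Say 1..2..3
```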
Reference
vLLM's implementation correctly merges all text blocks:

```python
text_prompt = "\n".join(texts)
```

See: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/chat_utils.py
Environment
```
/bin/sh: 1: /usr/local/cuda/bin/nvcc: not found
/bin/sh: 1: gcc: not found
sys.platform: linux
Python: 3.11.13 (main, Jun 5 2025, 13:12:00) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: Tesla V100-SXM2-16GB
CUDA_HOME: /usr/local/cuda
NVCC: Not Available
GCC: n/a
PyTorch: 2.8.0+cu128
PyTorch compiling details: PyTorch built with:
  - GCC 13.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.7.1 (Git Hash 8d263e693366ef8db40acc569cc7d8edf644556d)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.8
  - NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_120,code=sm_120
  - CuDNN 91.0.2 (built against CUDA 12.9)
  - Built with CuDNN 90.8
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=a1cb3cc05d46d198467bebbb6e8fba50a325d4e7, CUDA_VERSION=12.8, CUDNN_VERSION=9.8.0, CXX_COMPILER=/opt/rh/gcc-toolset-13/root/usr/bin/c++, CXX_FLAGS= -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, USE_XCCL=OFF, USE_XPU=OFF,
TorchVision: 0.23.0+cu128
LMDeploy: 0.10.1+unknown
transformers: 4.56.2
fastapi: 0.117.1
pydantic: 2.11.9
triton: 3.4.0
NVIDIA Topology:
        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV2     NV2     NV2     0,2,4,6,8,10    0               N/A
GPU1    NV2      X      NV2     NV2     0,2,4,6,8,10    0               N/A
GPU2    NV2     NV2      X      NV2     0,2,4,6,8,10    0               N/A
GPU3    NV2     NV2     NV2      X      0,2,4,6,8,10    0               N/A

Legend:
  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
```
Error traceback
No errors are raised, but the log shows the user's second text block was silently dropped (only 'This is a test.' appears in the prompt):

```
2025-10-11 14:53:48,450 - lmdeploy - INFO - logger.py:45 - session=9, adapter_name=None, input_tokens=13, gen_config=GenerationConfig(n=1, max_new_tokens=None, do_sample=True, top_p=1.0, top_k=40, min_p=0.0, temperature=0.7, repetition_penalty=1.0, ignore_eos=False, random_seed=3968898985380990939, stop_words=None, bad_words=None, stop_token_ids=[151643, 151645], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, spaces_between_special_tokens=True, logprobs=None, response_format=None, logits_processors=None, output_logits=None, output_last_hidden_state=None, include_stop_str_in_output=False, with_cache=False, preserve_cache=False, migration_request=None), prompt='<|im_start|>user\nThis is a test.<|im_end|>\n<|im_start|>assistant\n', prompt_token_id=[151644, 872, 198, 1986, 374, 264, 1273, 13, 151645, 198, 151644, 77091, 198]
```