
Commit c3ea878

Isotr0py authored and patrickvonplaten committed
[v1] Re-add fp32 support to v1 engine through FlexAttention (vllm-project#19754)
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Patrick von Platen <[email protected]>
1 parent 74b5990 commit c3ea878

File tree

8 files changed: +59 −12 lines changed
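For context, a minimal sketch of what this commit re-enables: requesting fp32 on the V1 engine, which now routes to the FlexAttention backend on CUDA instead of being rejected by the arg-parsing oracle. The model name and prompt below are illustrative only, and a CUDA GPU is assumed:

```python
from vllm import LLM, SamplingParams

# Assumption: a CUDA GPU with enough memory for the (illustrative) model.
llm = LLM(model="facebook/opt-125m", dtype="float32")  # previously rejected by the V1 oracle
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```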

.github/workflows/lint-and-deploy.yaml

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ jobs:
           export AWS_ACCESS_KEY_ID=minioadmin
           export AWS_SECRET_ACCESS_KEY=minioadmin
           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
-          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
+          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"

       - name: curl test
         run: |

tests/kernels/attention/test_attention_selector.py

Lines changed: 28 additions & 0 deletions
@@ -181,6 +181,34 @@ def test_env(
     assert backend.get_name() == expected


+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("use_v1", [True, False])
+def test_fp32_fallback(
+    device: str,
+    use_v1: bool,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test attention backend selection with fp32."""
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
+
+        if device == "cpu":
+            with patch("vllm.attention.selector.current_platform",
+                       CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                    if use_v1 else "TORCH_SDPA")
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                           16, False)
+            assert (backend.get_name() == "FLEX_ATTENTION"
+                    if use_v1 else "XFORMERS")
+
+
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 5 additions & 0 deletions
@@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):


 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():


 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():


 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():


 def test_init_kv_cache_without_kv_sharing():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():


 def test_init_kv_cache_with_kv_sharing_valid():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()

vllm/engine/arg_utils.py

Lines changed: 0 additions & 7 deletions
@@ -1393,13 +1393,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False

-        # Only Fp16 and Bf16 dtypes since we only support FA.
-        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
-        if model_config.dtype not in V1_SUPPORTED_DTYPES:
-            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
-                               recommend_to_remove=False)
-            return False
-
         # No Mamba or Encoder-Decoder so far.
         if not model_config.is_v1_compatible:
             _raise_or_fallback(feature_name=model_config.architectures,

vllm/model_executor/model_loader/tensorizer_loader.py

Lines changed: 6 additions & 2 deletions
@@ -104,8 +104,12 @@ def load_model(self, vllm_config: VllmConfig,

         if is_vllm_tensorized(self.tensorizer_config):
             tensorizer_config = self._patch_tensorizer_config(model_config)
-            model = init_tensorizer_model(tensorizer_config=tensorizer_config,
-                                          vllm_config=vllm_config)
+            device_config = vllm_config.device_config
+            with set_default_torch_dtype(model_config.dtype):
+                with torch.device(device_config.device):
+                    model = init_tensorizer_model(
+                        tensorizer_config=tensorizer_config,
+                        vllm_config=vllm_config)
             self.load_weights(model, model_config)
             return model
         return self._load_model_serialized_cpu(vllm_config=vllm_config)
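The tensorizer path now builds the model under the target dtype and device so fp32 configs initialize correctly. A rough sketch of a default-dtype context manager in the spirit of `set_default_torch_dtype` (an assumption about its behavior, not vLLM's actual implementation):

```python
import contextlib
import torch

@contextlib.contextmanager
def default_torch_dtype(dtype: torch.dtype):
    """Temporarily change the default floating-point dtype (illustrative helper)."""
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)

# Parameters created inside the block default to fp32, even if the process
# default differs, mirroring how the loader pins model_config.dtype.
with default_torch_dtype(torch.float32), torch.device("cpu"):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.dtype == torch.float32
```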

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -251,6 +251,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,

         # Default backends for V1 engine
         # Prefer FlashInfer for Blackwell GPUs if installed
+        if dtype not in (torch.float16, torch.bfloat16):
+            logger.info_once(
+                f"Using FlexAttention backend for {dtype} on V1 engine.")
+            return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
         if cls.is_device_capability(100):
             try:
                 import flashinfer  # noqa: F401
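In other words, any dtype other than fp16/bf16 short-circuits to FlexAttention before the FlashInfer/FlashAttention checks. A hypothetical standalone restatement of that rule (the function name here is illustrative, not a vLLM helper):

```python
import torch

def pick_default_v1_backend(dtype: torch.dtype) -> str:
    # fp32 (and any other non-fp16/bf16 dtype) falls back to FlexAttention on CUDA.
    if dtype not in (torch.float16, torch.bfloat16):
        return "FLEX_ATTENTION"
    # Simplified: the real code also probes FlashInfer on Blackwell GPUs.
    return "FLASH_ATTN"

assert pick_default_v1_backend(torch.float32) == "FLEX_ATTENTION"
assert pick_default_v1_backend(torch.bfloat16) == "FLASH_ATTN"
```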

vllm/v1/attention/backends/flex_attention.py

Lines changed: 11 additions & 1 deletion
@@ -463,6 +463,13 @@ def forward(
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
+
+        # default M=64, N=64 may run out of shared memory on
+        # some GPUs with fp32, so we use smaller M and N.
+        extra_kernel_options = {
+            "BLOCK_M": 32,
+            "BLOCK_N": 32
+        } if query.dtype == torch.float32 else {}
         out = flex_attention_compiled(
             query,
             key_cache,
@@ -471,7 +478,10 @@ def forward(
             attn_metadata.block_mask,
             self.scale,
             enable_gqa=enable_gqa,
-            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
+            kernel_options={
+                "FORCE_USE_FLEX_ATTENTION": True,
+                **extra_kernel_options
+            },
         )

         # Flex doesn't have an out variant today, rely on epilogue fusion
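For reference, a standalone sketch of passing tile sizes through FlexAttention's `kernel_options` (assumptions: PyTorch ≥ 2.5 with FlexAttention, a CUDA device, and illustrative shapes; vLLM's actual call site additionally sets `FORCE_USE_FLEX_ATTENTION` and a block mask):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

flex_compiled = torch.compile(flex_attention)

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)

# Smaller tiles keep the fp32 kernel within shared-memory limits on GPUs
# where the default 64x64 tiling would not fit.
extra = {"BLOCK_M": 32, "BLOCK_N": 32} if q.dtype == torch.float32 else {}
out = flex_compiled(q, k, v, kernel_options=extra)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```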

vllm/v1/sample/ops/topk_topp_sampler.py

Lines changed: 4 additions & 1 deletion
@@ -101,7 +101,10 @@ def forward_cuda(
                 "per-request generators. Falling back to "
                 "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(logits, k, p, generators)
+        # flashinfer sampling functions expect contiguous logits.
+        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
+        # because of slicing operation in logits_processor.
+        return flashinfer_sample(logits.contiguous(), k, p, generators)

     def forward_tpu(
         self,
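The `.contiguous()` call guards against sliced logits. A quick illustration of why slicing can break contiguity (plain PyTorch, shapes illustrative):

```python
import torch

logits = torch.randn(4, 32000)
sliced = logits[:, :16000]      # a view that keeps the parent's row stride
print(sliced.is_contiguous())   # False: stride is (32000, 1) but shape is (4, 16000)
print(sliced.contiguous().is_contiguous())  # True: data copied into a dense layout
```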
