Skip to content

[v1] Re-add fp32 support to v1 engine through FlexAttention #19754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint-and-deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"

- name: curl test
run: |
Expand Down
28 changes: 28 additions & 0 deletions tests/kernels/attention/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,34 @@ def test_env(
assert backend.get_name() == expected


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
device: str,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Test attention backend selection with fp32."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
if use_v1 else "TORCH_SDPA")

elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
assert (backend.get_name() == "FLEX_ATTENTION"
if use_v1 else "XFORMERS")


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
Expand Down
5 changes: 5 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):


def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer"
Expand Down Expand Up @@ -478,6 +479,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():


def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn"
Expand Down Expand Up @@ -506,6 +508,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():


def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer"
Expand Down Expand Up @@ -534,6 +537,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():


def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
Expand Down Expand Up @@ -601,6 +605,7 @@ def test_init_kv_cache_without_kv_sharing():


def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config()
Expand Down
7 changes: 0 additions & 7 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,13 +1393,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
recommend_to_remove=False)
return False

# No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/model_loader/tensorizer_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,12 @@ def load_model(self, vllm_config: VllmConfig,

if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
model = init_tensorizer_model(tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
device_config = vllm_config.device_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = init_tensorizer_model(
tensorizer_config=tensorizer_config,
vllm_config=vllm_config)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)
Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,

# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100):
try:
import flashinfer # noqa: F401
Expand Down
12 changes: 11 additions & 1 deletion vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,13 @@ def forward(
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)

# default M=64, N=64 may run out of shared memory on
# some GPUs with fp32, so we use smaller M and N.
extra_kernel_options = {
"BLOCK_M": 32,
"BLOCK_N": 32
} if query.dtype == torch.float32 else {}
out = flex_attention_compiled(
query,
key_cache,
Expand All @@ -471,7 +478,10 @@ def forward(
attn_metadata.block_mask,
self.scale,
enable_gqa=enable_gqa,
kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
kernel_options={
"FORCE_USE_FLEX_ATTENTION": True,
**extra_kernel_options
},
)

# Flex doesn't have an out variant today, rely on epilogue fusion
Expand Down
5 changes: 4 additions & 1 deletion vllm/v1/sample/ops/topk_topp_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def forward_cuda(
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return flashinfer_sample(logits, k, p, generators)
# flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators)

def forward_tpu(
self,
Expand Down