Skip to content

[MISC] Use non-blocking transfer in prepare_input #7172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.utils import make_tensor_with_pad
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
Expand Down Expand Up @@ -310,7 +310,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.tensor(input_block_tables).to(
device=device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
Expand All @@ -320,15 +321,15 @@ def build(self, seq_lens: List[int], query_lens: List[int],
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
Expand All @@ -344,10 +345,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
Expand Down
23 changes: 11 additions & 12 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
Expand Down Expand Up @@ -356,7 +357,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.tensor(input_block_tables).to(
device, non_blocking=True)

last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
Expand All @@ -371,12 +373,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
Expand All @@ -392,10 +395,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
Expand Down
27 changes: 12 additions & 15 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.utils import make_tensor_with_pad
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

# Error string(s) for encoder/decoder
# unsupported attention scenarios
Expand Down Expand Up @@ -181,7 +181,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.tensor(input_block_tables).to(
device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
Expand All @@ -191,15 +192,15 @@ def build(self, seq_lens: List[int], query_lens: List[int],
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)

context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
Expand All @@ -215,10 +216,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)

return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
Expand Down
15 changes: 8 additions & 7 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists,
get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
from vllm.worker.model_runner_base import (
Expand Down Expand Up @@ -549,12 +549,13 @@ def build(self) -> ModelInputForGPU:
# Tokens and positions.
input_tokens.extend([0] * cuda_graph_pad_size)
input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
self.runner.pin_memory)
input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
self.runner.device,
self.runner.pin_memory)

# Sequence and query lengths.
seq_lens.extend([1] * cuda_graph_pad_size)
Expand Down
Loading