4 changes: 3 additions & 1 deletion examples/offline_inference/tpu.py
@@ -21,7 +21,9 @@

# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(model="google/gemma-2b", enforce_eager=True)
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4)
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
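For context, a minimal, self-contained sketch of how the updated example would be run; the prompt and sampling parameters below are placeholders rather than the ones defined in the actual example file.

```python
from vllm import LLM, SamplingParams

# Placeholder prompt and sampling settings, not the ones from tpu.py.
prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Small buckets keep the number of precompiled TPU graphs low:
# max_num_batched_tokens bounds the padded token dimension and
# max_num_seqs bounds the padded request dimension.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          max_num_batched_tokens=64,
          max_num_seqs=4)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```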
47 changes: 38 additions & 9 deletions vllm/v1/worker/tpu_model_runner.py
@@ -401,6 +401,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
self.query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
out=self.query_start_loc_np[1:num_reqs + 1])
self.query_start_loc_np[num_reqs + 1:] = 1

self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
@@ -441,7 +442,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs)
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices

def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
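To make the padding above concrete, here is a small NumPy sketch with illustrative numbers (not taken from the PR): three real requests, a request bucket of four, and a `query_start_loc` buffer sized `max_num_reqs + 1`.

```python
import numpy as np

max_num_reqs = 4                 # illustrative upper limit
num_reqs = 3                     # real requests in this step
padded_num_reqs = 4              # _get_padded_num_reqs_with_upper_limit(3, 4)
num_scheduled_tokens_per_req = np.array([5, 2, 7], dtype=np.int32)

query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
query_start_loc[0] = 0
np.cumsum(num_scheduled_tokens_per_req, out=query_start_loc[1:num_reqs + 1])
query_start_loc[num_reqs + 1:] = 1           # pad the tail, as in the PR
print(query_start_loc)                       # [ 0  5  7 14  1]

# Index of the last scheduled token per request, padded to the request
# bucket; the padded entries point at token 0 and are discarded later.
logits_indices = query_start_loc[1:padded_num_reqs + 1] - 1
print(logits_indices)                        # [ 4  6 13  0]
```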
@@ -551,7 +555,6 @@ def execute_model(

# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
@@ -579,12 +582,10 @@ def execute_model(
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
num_reqs = self.input_batch.num_reqs
logits_indices = logits_indices[:num_reqs]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
selected_token_ids = self.model.compute_logits(hidden_states,
logits_indices, None)
selected_token_ids = selected_token_ids.cpu()[:num_reqs]

# Then, let's update the cache state.
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
@@ -726,12 +727,31 @@ def _dummy_run(

with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(
hidden_states = self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = _get_padded_num_reqs_with_upper_limit(
64, self.max_num_reqs)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while True:
logits_indices = torch.zeros(
num_reqs,
dtype=torch.int32,
device=self.device,
)
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(logits_indices, 0)
self.model.compute_logits(hidden_states, logits_indices, None)
if num_reqs >= self.max_num_reqs:
break
num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs + 1, self.max_num_reqs)

def capture_model(self) -> None:
"""Compile the model."""
@@ -823,13 +843,17 @@ def forward(

return hidden_states

@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits(
self,
hidden_states: torch.Tensor,
logits_indices: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, sampling_metadata)
return logits
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
return selected_token_ids

def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
@@ -846,3 +870,8 @@ def _get_padded_token_len(x: int) -> int:
if x <= 16:
return 16
return 1 << (x - 1).bit_length()


def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
return min(res, upper_limit)
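As a quick usage note for the helper just added (and `_get_padded_token_len` above), a few illustrative input/output pairs, assuming both helpers are in scope:

```python
# Illustrative values only.
assert _get_padded_token_len(1) == 16
assert _get_padded_token_len(17) == 32
assert _get_padded_token_len(100) == 128

assert _get_padded_num_reqs_with_upper_limit(3, 256) == 64
assert _get_padded_num_reqs_with_upper_limit(65, 256) == 128
assert _get_padded_num_reqs_with_upper_limit(300, 256) == 256  # capped at the limit
```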