
Commit 4727865

KazusatoOko authored and committed
[BugFix]: Batch generation from prompt_embeds fails for long prompts (vllm-project#21390)
Signed-off-by: KazusatoOko <[email protected]>
Co-authored-by: KazusatoOko <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
1 parent 66cae7e commit 4727865
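For context, the sketch below shows the kind of batched prompt_embeds request this fix targets. It is not taken from the commit: the model name and prompt lengths are illustrative, and it assumes vLLM's documented prompt-embeds API (enable_prompt_embeds=True on LLM, one {"prompt_embeds": tensor} entry per request), with the embeddings computed through Hugging Face transformers.

```python
# Hypothetical reproduction sketch: batch generation from prompt_embeds.
# Model name, prompt lengths, and the enable_prompt_embeds flag are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "facebook/opt-125m"  # illustrative only
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = hf_model.get_input_embeddings()

llm = LLM(model=model_name, enable_prompt_embeds=True)

# A batch mixing short and long prompts; long prompts are what exposed the
# bug, since their prefill can span multiple scheduler steps.
prompts = ["Hi!", "word " * 1500]
batch = []
for text in prompts:
    token_ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        prompt_embeds = embedding_layer(token_ids).squeeze(0)  # (seq_len, hidden)
    batch.append({"prompt_embeds": prompt_embeds})

outputs = llm.generate(batch, SamplingParams(max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)
```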

File tree

1 file changed (+22, -14)


vllm/worker/model_runner.py

Lines changed: 22 additions & 14 deletions
@@ -1785,24 +1785,32 @@ def execute_model(
 
         if model_input.inputs_embeds is not None:
             if self.is_driver_worker:
-                sampled = broadcast_tensor_dict(
-                    {"token_ids": output.sampled_token_ids})
+                sampled_token_ids = []
+                valid_outputs = []
+                for sequence_group_output in output.outputs:
+                    if len(sequence_group_output.samples) == 0:
+                        continue
+                    assert len(sequence_group_output.samples) == 1
+                    valid_outputs.append(sequence_group_output)
+                    sampled_token_ids.append(
+                        sequence_group_output.samples[0].output_token)
+                sampled_token_ids = torch.tensor(sampled_token_ids).to(
+                    self.device)
+                sampled_token_ids = broadcast_tensor_dict(
+                    {"sampled_token_ids":
+                     sampled_token_ids})["sampled_token_ids"]
             else:
-                sampled = broadcast_tensor_dict()
-            if sampled["token_ids"] is not None:
-                sampled_token_embeds = self.model.get_input_embeddings(
-                    sampled["token_ids"].squeeze(1))
+                sampled_token_ids = broadcast_tensor_dict(
+                )["sampled_token_ids"]
+            if len(sampled_token_ids) > 0:
+                sampled_token_embeds = \
+                    self.model.get_input_embeddings(sampled_token_ids)
                 if self.is_driver_worker:
                     self.sampler.include_gpu_probs_tensor = \
                         orig_include_gpu_probs
-
-                    output.sampled_token_embeds = sampled_token_embeds
-
-                    for token_embed, sequence_group_output in zip(
-                            output.sampled_token_embeds, output.outputs):
-                        assert len(sequence_group_output.samples) == 1
-                        sequence_group_output.samples[
-                            0].output_embed = token_embed
+                    for i, sequence_group_output in enumerate(valid_outputs):
+                        sequence_group_output.samples[0].output_embed = \
+                            sampled_token_embeds[i]
 
         if not self.is_driver_worker:
             return []
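The old driver path assumed every sequence group in output.outputs had exactly one sample per step. With long prompts (e.g., when prefill is chunked across scheduler steps), some groups produce no sample in a given step, so zipping output.outputs against the embedding rows drifted out of alignment. The fix gathers token ids only from groups that actually sampled, broadcasts them so non-driver tensor-parallel ranks can compute the same embeddings, and writes each embedding back onto its own valid group. The following self-contained sketch, using plain stand-in classes rather than vLLM's SequenceGroupOutput types, illustrates that gather-then-writeback pattern:

```python
# Toy stand-ins for vLLM's per-step sampler outputs; not the real types.
from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class Sample:
    output_token: int
    output_embed: Optional[torch.Tensor] = None


@dataclass
class SeqGroupOutput:
    samples: List[Sample] = field(default_factory=list)


def attach_sampled_embeds(outputs: List[SeqGroupOutput],
                          embedding: torch.nn.Embedding) -> None:
    """Write each newly sampled token's embedding back onto its group."""
    # Only groups that sampled this step contribute a token id; groups still
    # mid-prefill (long prompts) have empty `samples` and are skipped.
    valid = [o for o in outputs if o.samples]
    if not valid:
        return
    token_ids = torch.tensor([o.samples[0].output_token for o in valid])
    embeds = embedding(token_ids)  # one row per *valid* group
    for group, embed in zip(valid, embeds):
        group.samples[0].output_embed = embed


embedding = torch.nn.Embedding(100, 8)
step = [SeqGroupOutput([Sample(3)]),   # decoded a token
        SeqGroupOutput(),              # still prefilling a long prompt
        SeqGroupOutput([Sample(7)])]
attach_sampled_embeds(step, embedding)
assert step[0].samples[0].output_embed is not None
assert step[1].samples == []  # skipped, no misalignment
```

In the actual diff the gathered ids are additionally moved to self.device and shared via broadcast_tensor_dict, so that every rank runs get_input_embeddings on the same token batch.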
