Skip to content

Commit ced790b

Browse files
garg-amit authored and youkaichao committed
[torch.compile] fix tensor alias (vllm-project#8982)
Signed-off-by: Amit Garg <[email protected]>
1 parent 31e005d commit ced790b

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

vllm/worker/embedding_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def execute_model(
103103
# a placeholder (it has wide hardware support).
104104
kv_caches = [
105105
torch.tensor([], dtype=torch.float32, device=self.device)
106-
] * num_layers
106+
for _ in range(num_layers)
107+
]
107108

108109
execute_model_kwargs = {
109110
"input_ids":

vllm/worker/enc_dec_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ def profile_run(self) -> None:
348348
# a placeholder (it has wide hardware support).
349349
kv_caches = [
350350
torch.tensor([], dtype=torch.float32, device=self.device)
351-
] * num_layers
351+
for _ in range(num_layers)
352+
]
352353
finished_requests_ids = [seq.request_id for seq in seqs]
353354
model_input = self.prepare_model_input(
354355
seqs, finished_requests_ids=finished_requests_ids)

vllm/worker/model_runner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,9 +1276,13 @@ def profile_run(self) -> None:
12761276
# it by reference, rather by specializing on the value ``None``.
12771277
# the `dtype` argument does not matter, and we use `float32` as
12781278
# a placeholder (it has wide hardware support).
1279+
# it is important to create tensors inside the loop, rather than
1280+
# multiplying the list, to avoid Dynamo from treating them as
1281+
# tensor aliasing.
12791282
kv_caches = [
12801283
torch.tensor([], dtype=torch.float32, device=self.device)
1281-
] * num_layers
1284+
for _ in range(num_layers)
1285+
]
12821286
finished_requests_ids = [seq.request_id for seq in seqs]
12831287
model_input = self.prepare_model_input(
12841288
seqs, finished_requests_ids=finished_requests_ids)

0 commit comments

Comments (0)