Commit 98f31cd

add emtpy_cache() after each padding (opendatahub-io#120)
1 parent: c034d5d

File tree

1 file changed: +2 −0 lines changed

vllm/model_executor/models/mixtral.py

Lines changed: 2 additions & 0 deletions
@@ -187,9 +187,11 @@ def process_weights_after_loading(self):
         self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
                                              (0, 128), "constant", 0),
                                        requires_grad=False)
+        torch.cuda.empty_cache()
         self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
                                             (0, 128), "constant", 0),
                                       requires_grad=False)
+        torch.cuda.empty_cache()
         return

         # If checkpoint is fp16, quantize here.
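The rationale behind the diff: F.pad allocates a new, larger tensor, so the pre-padding weight becomes unreferenced, but PyTorch's CUDA caching allocator keeps its blocks reserved; calling torch.cuda.empty_cache() after each pad returns those blocks to the driver, lowering peak reserved memory between the two large pads. Below is a minimal, hedged sketch of the same pattern on a stand-in tensor (the shape is illustrative, not the real Mixtral MoE weight shape), guarded so it also runs on CPU-only machines:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for an expert weight; the real w13_weight is much larger.
w13 = nn.Parameter(torch.randn(8, 1024), requires_grad=False)

# Pad the last dimension by 128 zeros, as the commit does, which
# allocates a NEW tensor and orphans the old one.
w13 = nn.Parameter(F.pad(w13.data, (0, 128), "constant", 0),
                   requires_grad=False)

if torch.cuda.is_available():
    # Release the cached blocks left behind by the pre-padding tensor
    # so the next pad does not raise peak reserved memory further.
    torch.cuda.empty_cache()

print(w13.shape)  # torch.Size([8, 1152])
```

Note that empty_cache() does not free tensors that are still referenced; it only releases cached-but-unused blocks, which is why it is placed immediately after each reassignment.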
