Skip to content

Commit 2d9ce44

Browse files
duj12dujingBarabazs
authored
fix(asr): load VAD model on correct CUDA device (#835)
fix(asr): load VAD model on correct CUDA device Previously, the VAD sub‐model was always initialized on the default CUDA device (cuda:0), even when a higher device_index was specified. This change sets `device_vad` to `cuda:{device_index}` whenever `device == 'cuda'`, while falling back to the original `device` string for non‐CUDA cases. This ensures the VAD model is loaded on the intended GPU. Co-authored-by: dujing <[email protected]> Co-authored-by: Barabazs <[email protected]>
1 parent f4261f3 commit 2d9ce44

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

whisperx/asr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,11 @@ def load_model(
401401
if vad_method == "silero":
402402
vad_model = Silero(**default_vad_options)
403403
elif vad_method == "pyannote":
404-
vad_model = Pyannote(torch.device(device), use_auth_token=None, **default_vad_options)
404+
if device == 'cuda':
405+
device_vad = f'cuda:{device_index}'
406+
else:
407+
device_vad = device
408+
vad_model = Pyannote(torch.device(device_vad), use_auth_token=None, **default_vad_options)
405409
else:
406410
raise ValueError(f"Invalid vad_method: {vad_method}")
407411

0 commit comments

Comments
 (0)