Skip to content

Commit f04f5df

Browse files
authored
Use CUDA for audio decoding unless you are on 'mps' and need to use CPU (#98)
Previously, both CUDA and MPS tensors were being moved to CPU for audio decoding. This was only necessary for MPS devices due to embedding operation limitations. Now CUDA tensors remain on GPU for decoding, improving performance while maintaining MPS compatibility.
1 parent 14f5f72 commit f04f5df

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def generate(
374374
concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1)
375375

376376
# Fix MPS compatibility: detach and move to CPU before decoding
377-
if concat_audio_out_ids.device.type in ["mps", "cuda"]:
377+
if concat_audio_out_ids.device.type == "mps":
378378
concat_audio_out_ids_cpu = concat_audio_out_ids.detach().cpu()
379379
else:
380380
concat_audio_out_ids_cpu = concat_audio_out_ids

0 commit comments

Comments
 (0)