Skip to content

Commit 77aaefa

Browse files
lucasnewmanBlaizzy
andauthored
Add voice matching support for Dia (#93)
* Add voice matching support for Dia. * Fix codec test failures from MLX update. --------- Co-authored-by: Prince Canuma <[email protected]>
1 parent 8652b35 commit 77aaefa

File tree

4 files changed

+24
-27
lines changed

4 files changed

+24
-27
lines changed

mlx_audio/codec/tests/test_descript.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_descript_16khz(self):
3838
self.assertEqual(latents.shape, (1, 96, 250))
3939

4040
y = model.decode(z).squeeze(-1)
41-
self.assertEqual(y.shape, (1, 79_992))
41+
self.assertEqual(y.shape, (1, 80_043))
4242

4343
def test_descript_24khz(self):
4444
audio = mx.zeros((1, 1, 120_000))
@@ -70,7 +70,7 @@ def test_descript_24khz(self):
7070
self.assertEqual(latents.shape, (1, 256, 375))
7171

7272
y = model.decode(z).squeeze(-1)
73-
self.assertEqual(y.shape, (1, 119_992))
73+
self.assertEqual(y.shape, (1, 120_043))
7474

7575
def test_descript_44khz(self):
7676
audio = mx.zeros((1, 1, 220_000))
@@ -102,7 +102,7 @@ def test_descript_44khz(self):
102102
self.assertEqual(latents.shape, (1, 72, 430))
103103

104104
y = model.decode(z).squeeze(-1)
105-
self.assertEqual(y.shape, (1, 220_160))
105+
self.assertEqual(y.shape, (1, 220_235))
106106

107107

108108
if __name__ == "__main__":

mlx_audio/codec/tests/test_snac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_snac(self):
3333
self.assertEqual(codes[2].shape, (1, 236))
3434

3535
reconstructed = model.decode(codes).squeeze(-1)
36-
self.assertEqual(reconstructed.shape, (1, 120_832))
36+
self.assertEqual(reconstructed.shape, (1, 120_907))
3737

3838

3939
if __name__ == "__main__":

mlx_audio/tts/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def generate_audio(
7474
if ref_audio:
7575
if not os.path.exists(ref_audio):
7676
raise FileNotFoundError(f"Reference audio file not found: {ref_audio}")
77-
ref_audio = load_audio(ref_audio)
77+
ref_audio = load_audio(ref_audio, sample_rate=sample_rate)
7878
if not ref_text:
7979
print("Ref_text not found. Transcribing ref_audio...")
8080
# mlx_whisper seems takes long time to import. Import only necessary.

mlx_audio/tts/models/dia/dia.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
import mlx.core as mx
55
import mlx.nn as nn
66
import numpy as np
7-
import soundfile as sf
87
from huggingface_hub import hf_hub_download
98
from mlx_lm.sample_utils import make_sampler
10-
from scipy import signal
119
from tqdm import trange
1210

1311
from mlx_audio.codec.models import DAC
@@ -18,14 +16,6 @@
1816
from .layers import DiaModel, KVCache
1917

2018

21-
def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
22-
gcd = np.gcd(orig_sr, target_sr)
23-
up = target_sr // gcd
24-
down = orig_sr // gcd
25-
resampled = signal.resample_poly(audio, up, down, padtype="edge")
26-
return resampled
27-
28-
2919
def _sample_next_token(
3020
logits_BCxV: mx.array,
3121
temperature: float,
@@ -226,6 +216,8 @@ def generate(
226216
split_pattern: str = "\n",
227217
max_tokens: int | None = None,
228218
verbose: bool = False,
219+
ref_audio: Optional[mx.array] = None,
220+
ref_text: Optional[str] = None,
229221
**kwargs,
230222
):
231223
prompt = text.replace("\\n", "\n").replace("\\t", "\t")
@@ -239,6 +231,8 @@ def generate(
239231
audio = self._generate(
240232
prompt,
241233
max_tokens=max_tokens,
234+
ref_audio=ref_audio,
235+
ref_text=ref_text,
242236
)
243237
all_audio.append(audio[None, ...])
244238

@@ -291,13 +285,14 @@ def generate(
291285
def _generate(
292286
self,
293287
text: str,
294-
max_tokens: int | None = None,
288+
max_tokens: Optional[int] = None,
295289
cfg_scale: float = 3.0,
296290
temperature: float = 1.3,
297291
top_p: float = 0.95,
298292
use_cfg_filter: bool = True,
299293
cfg_filter_top_k: int = 35,
300-
audio_prompt_path: str | None = None,
294+
ref_audio: Optional[mx.array] = None,
295+
ref_text: Optional[str] = None,
301296
) -> np.ndarray:
302297
"""
303298
Generates audio from a text prompt (and optional audio prompt) using the Dia model.
@@ -314,6 +309,9 @@ def _generate(
314309
delay_tensor = mx.array(delay_pattern, dtype=mx.int32)
315310
max_delay_pattern = max(delay_pattern)
316311

312+
if ref_text is not None:
313+
text = ref_text.strip() + " " + text
314+
317315
(
318316
cond_src_BxS,
319317
cond_src_positions_BxS,
@@ -370,19 +368,18 @@ def _generate(
370368
prompt_len_inc_bos = 1 # Start with BOS length
371369

372370
# 3-3. Load Audio Prompt (if provided)
373-
if audio_prompt_path is not None:
374-
audio_prompt, sr = sf.read(audio_prompt_path) # C, T
375-
if sr != 44100: # Resample to 44.1kHz
376-
audio_prompt = resample_audio(audio_prompt, sr, 44100)
377-
audio_prompt = audio_prompt.unsqueeze(0) # 1, C, T
371+
if ref_audio is not None:
372+
audio_prompt = mx.array(ref_audio)[None, None, ...] # 1, C, T
378373

379374
audio_prompt_codebook = audio_to_codebook(
380375
self.dac_model, audio_prompt, data_config=self.config.data
381376
)
382-
audio_prompt_mx = mx.array(audio_prompt_codebook.numpy())
383-
384-
audio_prompt_mx = mx.concatenate([audio_prompt_mx, audio_prompt_mx], axis=0)
385-
generated_BxTxC = mx.concatenate([generated_BxTxC, audio_prompt_mx], axis=1)
377+
audio_prompt_codebook = mx.concatenate(
378+
[audio_prompt_codebook, audio_prompt_codebook], axis=0
379+
)
380+
generated_BxTxC = mx.concatenate(
381+
[generated_BxTxC, audio_prompt_codebook], axis=1
382+
)
386383

387384
prefill_len = generated_BxTxC.shape[1]
388385
prompt_len_inc_bos = prefill_len
@@ -499,7 +496,7 @@ def _generate(
499496
)
500497

501498
generation_step_index = step - current_step
502-
if audio_prompt_path is None:
499+
if ref_audio is None:
503500
pred_C = mx.where(
504501
generation_step_index >= delay_tensor,
505502
pred_C,

0 commit comments

Comments
 (0)