44import mlx .core as mx
55import mlx .nn as nn
66import numpy as np
7- import soundfile as sf
87from huggingface_hub import hf_hub_download
98from mlx_lm .sample_utils import make_sampler
10- from scipy import signal
119from tqdm import trange
1210
1311from mlx_audio .codec .models import DAC
1816from .layers import DiaModel , KVCache
1917
2018
def resample_audio(
    audio: np.ndarray, orig_sr: int, target_sr: int, axis: int = 0
) -> np.ndarray:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` by polyphase filtering.

    Args:
        audio: Samples to resample.
        orig_sr: Sample rate of ``audio``; must be positive.
        target_sr: Desired sample rate; must be positive.
        axis: Axis of ``audio`` along which to resample. Defaults to 0, the
            axis ``scipy.signal.resample_poly`` uses when none is given, so
            existing three-argument callers see no change.

    Returns:
        The array produced by ``scipy.signal.resample_poly`` with edge
        padding applied at the signal boundaries.

    Raises:
        ValueError: If ``orig_sr`` or ``target_sr`` is not positive.
    """
    # Fail early with a message that names the sample rates, instead of
    # letting SciPy complain about the derived up/down factors.
    if orig_sr <= 0 or target_sr <= 0:
        raise ValueError(
            "Sample rates must be positive, got "
            f"orig_sr={orig_sr}, target_sr={target_sr}"
        )

    # Reduce the rate ratio to lowest terms before handing it to SciPy.
    common = int(np.gcd(orig_sr, target_sr))
    up = target_sr // common
    down = orig_sr // common
    return signal.resample_poly(audio, up, down, axis=axis, padtype="edge")
27-
28-
2919def _sample_next_token (
3020 logits_BCxV : mx .array ,
3121 temperature : float ,
@@ -226,6 +216,8 @@ def generate(
226216 split_pattern : str = "\n " ,
227217 max_tokens : int | None = None ,
228218 verbose : bool = False ,
219+ ref_audio : Optional [mx .array ] = None ,
220+ ref_text : Optional [str ] = None ,
229221 ** kwargs ,
230222 ):
231223 prompt = text .replace ("\\ n" , "\n " ).replace ("\\ t" , "\t " )
@@ -239,6 +231,8 @@ def generate(
239231 audio = self ._generate (
240232 prompt ,
241233 max_tokens = max_tokens ,
234+ ref_audio = ref_audio ,
235+ ref_text = ref_text ,
242236 )
243237 all_audio .append (audio [None , ...])
244238
@@ -291,13 +285,14 @@ def generate(
291285 def _generate (
292286 self ,
293287 text : str ,
294- max_tokens : int | None = None ,
288+ max_tokens : Optional [ int ] = None ,
295289 cfg_scale : float = 3.0 ,
296290 temperature : float = 1.3 ,
297291 top_p : float = 0.95 ,
298292 use_cfg_filter : bool = True ,
299293 cfg_filter_top_k : int = 35 ,
300- audio_prompt_path : str | None = None ,
294+ ref_audio : Optional [mx .array ] = None ,
295+ ref_text : Optional [str ] = None ,
301296 ) -> np .ndarray :
302297 """
303298 Generates audio from a text prompt (and optional audio prompt) using the Dia model.
@@ -314,6 +309,9 @@ def _generate(
314309 delay_tensor = mx .array (delay_pattern , dtype = mx .int32 )
315310 max_delay_pattern = max (delay_pattern )
316311
312+ if ref_text is not None :
313+ text = ref_text .strip () + " " + text
314+
317315 (
318316 cond_src_BxS ,
319317 cond_src_positions_BxS ,
@@ -370,19 +368,18 @@ def _generate(
370368 prompt_len_inc_bos = 1 # Start with BOS length
371369
372370 # 3-3. Load Audio Prompt (if provided)
373- if audio_prompt_path is not None :
374- audio_prompt , sr = sf .read (audio_prompt_path ) # C, T
375- if sr != 44100 : # Resample to 44.1kHz
376- audio_prompt = resample_audio (audio_prompt , sr , 44100 )
377- audio_prompt = audio_prompt .unsqueeze (0 ) # 1, C, T
371+ if ref_audio is not None :
372+ audio_prompt = mx .array (ref_audio )[None , None , ...] # 1, C, T
378373
379374 audio_prompt_codebook = audio_to_codebook (
380375 self .dac_model , audio_prompt , data_config = self .config .data
381376 )
382- audio_prompt_mx = mx .array (audio_prompt_codebook .numpy ())
383-
384- audio_prompt_mx = mx .concatenate ([audio_prompt_mx , audio_prompt_mx ], axis = 0 )
385- generated_BxTxC = mx .concatenate ([generated_BxTxC , audio_prompt_mx ], axis = 1 )
377+ audio_prompt_codebook = mx .concatenate (
378+ [audio_prompt_codebook , audio_prompt_codebook ], axis = 0
379+ )
380+ generated_BxTxC = mx .concatenate (
381+ [generated_BxTxC , audio_prompt_codebook ], axis = 1
382+ )
386383
387384 prefill_len = generated_BxTxC .shape [1 ]
388385 prompt_len_inc_bos = prefill_len
@@ -499,7 +496,7 @@ def _generate(
499496 )
500497
501498 generation_step_index = step - current_step
502- if audio_prompt_path is None :
499+ if ref_audio is None :
503500 pred_C = mx .where (
504501 generation_step_index >= delay_tensor ,
505502 pred_C ,
0 commit comments