Commit f385b21

Blaizzy and lucasnewman authored

Server v2 (#153)

* base arch of server
* add tts and stt endpoints
* functioning server
* connect server and ui
* Add audio utilities, use them where possible (#161)
* Add audio utilities, use them where possible.
* Formatting.
* Fix tests.
* Fix tests.
* More test fixes.
* fix server
* Fix join audio sample rate (#162)
* update nextjs
* fix stt view
* working STT
* working text to speech
* remove voices
* remove home
* add custom model and delete file
* refactor model mapping
* add animation and use env vars for frontend config
* remove unused
* refactor model loading
* add tests
* mock generate
* fix tests
* remove old player
* update readme

Co-authored-by: Lucas Newman <[email protected]>

1 parent 9042b56 commit f385b21

File tree

26 files changed: +1536 / -2675 lines changed

README.md

Lines changed: 56 additions & 37 deletions

````diff
@@ -64,27 +64,40 @@ print("Audiobook chapter successfully generated!")
 
 ```
 
-### Web Interface & API Server
-
-MLX-Audio includes a web interface with a 3D visualization that reacts to audio frequencies. The interface allows you to:
-
-1. Generate TTS with different voices and speed settings
-2. Upload and play your own audio files
-3. Visualize audio with an interactive 3D orb
-4. Automatically saves generated audio files to the outputs directory in the current working folder
-5. Open the output folder directly from the interface (when running locally)
-
-#### Features
+### Web Interface & FastAPI Server
+
+MLX-Audio provides a modern web interface with real-time audio visualization capabilities. The interface offers:
+
+1. Text-to-Speech generation with customizable voices and parameters
+2. Speech-to-Text transcription with support for multiple languages
+3. Audio file upload and playback functionality
+4. Interactive 3D audio visualization
+5. Automatic audio file management in the outputs directory
+6. Direct access to the output folder from the interface (local deployment only)
+
+#### Key Features
+
+- **Voice Customization**: Select from multiple voice presets including AF Heart, AF Nova, AF Bella, and BF Emma
+- **Speech Rate Control**: Fine-tune speech generation speed using an intuitive slider (range: 0.5x - 2.0x)
+- **Dynamic 3D Visualization**: Experience audio through an interactive 3D orb that responds to frequency changes
+- **Audio Management**: Upload, play, and visualize custom audio files
+- **Smart Playback**: Optional automatic playback of generated audio
+- **File Management**: Quick access to the output directory through an integrated file explorer button
+- **Speech Recognition**: Convert speech to text with support for multiple languages and models
+
+To start the web interface and API server:
 
-- **Multiple Voice Options**: Choose from different voice styles (AF Heart, AF Nova, AF Bella, BF Emma)
-- **Adjustable Speech Speed**: Control the speed of speech generation with an interactive slider (0.5x to 2.0x)
-- **Real-time 3D Visualization**: A responsive 3D orb that reacts to audio frequencies
-- **Audio Upload**: Play and visualize your own audio files
-- **Auto-play Option**: Automatically play generated audio
-- **Output Folder Access**: Convenient button to open the output folder in your system's file explorer
+UI:
+```bash
+# Configure the API base URL and port
+export NEXT_PUBLIC_API_BASE_URL=http://localhost
+export NEXT_PUBLIC_API_PORT=8000
 
-To start the web interface and API server:
+# Start UI server
+cd mlx_audio/ui
+npm run dev
+```
 
+Server:
 ```bash
 # Using the command-line interface
 mlx_audio.server
@@ -109,26 +122,23 @@ http://127.0.0.1:8000
 
 The server provides the following REST API endpoints:
 
-- `POST /tts`: Generate TTS audio
-  - Parameters (form data):
-    - `text`: The text to convert to speech (required)
-    - `voice`: Voice to use (default: "af_heart")
-    - `speed`: Speech speed from 0.5 to 2.0 (default: 1.0)
-  - Returns: JSON with filename of generated audio
-
-- `GET /audio/(unknown)`: Retrieve generated audio file
-
-- `POST /play`: Play audio directly from the server
-  - Parameters (form data):
-    - `filename`: The filename of the audio to play (required)
-  - Returns: JSON with status and filename
+- `POST /v1/audio/speech`: Generate speech from text following the OpenAI TTS specification.
+  - JSON body parameters:
+    - `model`: Name or path of the TTS model to use.
+    - `input`: Text to convert to speech.
+    - `voice`: Optional voice preset.
+    - `speed`: Optional speech speed (default `1.0`).
+  - Returns the generated audio in WAV format.
 
-- `POST /stop`: Stop any currently playing audio
-  - Returns: JSON with status
+- `POST /v1/audio/transcriptions`: Transcribe audio files using an STT model in a format compatible with OpenAI's API.
+  - Multipart form parameters:
+    - `file`: The audio file to transcribe.
+    - `model`: Name or path of the STT model.
+  - Returns JSON containing the transcribed `text`.
 
-- `POST /open_output_folder`: Open the output folder in the system's file explorer
-  - Returns: JSON with status and path
-  - Note: This feature only works when running the server locally
+- `GET /v1/models`: List loaded models.
+- `POST /v1/models`: Load a model by name.
+- `DELETE /v1/models`: Unload a model.
 
 > Note: Generated audio files are stored in `~/.mlx_audio/outputs` by default, or in a fallback directory if that location is not writable.
 
@@ -217,7 +227,7 @@ mx.save_safetensors("./8bit/kokoro-v1_0.safetensors", weights, metadata={"format
 - For the web interface and API:
   - FastAPI
   - Uvicorn
-
+
 ## License
 
 [MIT License](LICENSE)
@@ -227,3 +237,12 @@ mx.save_safetensors("./8bit/kokoro-v1_0.safetensors", weights, metadata={"format
 - Thanks to the Apple MLX team for providing a great framework for building TTS and STS models.
 - This project uses the Kokoro model architecture for text-to-speech synthesis.
 - The 3D visualization uses Three.js for rendering.
+
+
+@misc{mlx-audio,
+  author = {Canuma, Prince},
+  title = {MLX Audio},
+  year = {2025},
+  howpublished = {\url{https://github.com/Blaizzy/mlx-audio}},
+  note = {A text-to-speech (TTS), speech-to-text (STT) and speech-to-speech (STS) library built on Apple's MLX framework, providing efficient speech analysis on Apple Silicon.}
+}
````
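As a rough illustration of the new OpenAI-style TTS endpoint documented above, the JSON body for `POST /v1/audio/speech` can be assembled with the standard library. The helper name `build_speech_request` and the model string below are hypothetical placeholders for this sketch, not part of MLX-Audio itself:

```python
import json
from typing import Optional


def build_speech_request(model: str, text: str, voice: Optional[str] = None,
                         speed: float = 1.0) -> dict:
    # Mirrors the documented /v1/audio/speech JSON body:
    # `model`, `input`, optional `voice`, optional `speed` (default 1.0).
    body = {"model": model, "input": text, "speed": speed}
    if voice is not None:
        body["voice"] = voice
    return body


payload = build_speech_request("path/to/tts-model", "Hello from MLX-Audio!",
                               voice="af_heart")
print(json.dumps(payload, indent=2))
```

Posting this payload with any HTTP client to `http://127.0.0.1:8000/v1/audio/speech` should, per the endpoint description above, return the generated audio as WAV.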

mlx_audio/codec/models/vocos/mel.py

Lines changed: 2 additions & 146 deletions

```diff
@@ -1,152 +1,8 @@
 from __future__ import annotations
 
-import math
-from functools import lru_cache
-from typing import Optional
-
 import mlx.core as mx
 
-
-@lru_cache(maxsize=None)
-def mel_filters(
-    sample_rate: int,
-    n_fft: int,
-    n_mels: int,
-    f_min: float = 0,
-    f_max: Optional[float] = None,
-    norm: Optional[str] = None,
-    mel_scale: str = "htk",
-) -> mx.array:
-    def hz_to_mel(freq, mel_scale="htk"):
-        if mel_scale == "htk":
-            return 2595.0 * math.log10(1.0 + freq / 700.0)
-
-        # slaney scale
-        f_min, f_sp = 0.0, 200.0 / 3
-        mels = (freq - f_min) / f_sp
-        min_log_hz = 1000.0
-        min_log_mel = (min_log_hz - f_min) / f_sp
-        logstep = math.log(6.4) / 27.0
-        if freq >= min_log_hz:
-            mels = min_log_mel + math.log(freq / min_log_hz) / logstep
-        return mels
-
-    def mel_to_hz(mels, mel_scale="htk"):
-        if mel_scale == "htk":
-            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
-
-        # slaney scale
-        f_min, f_sp = 0.0, 200.0 / 3
-        freqs = f_min + f_sp * mels
-        min_log_hz = 1000.0
-        min_log_mel = (min_log_hz - f_min) / f_sp
-        logstep = math.log(6.4) / 27.0
-        freqs = mx.where(
-            mels >= min_log_mel,
-            min_log_hz * mx.exp(logstep * (mels - min_log_mel)),
-            freqs,
-        )
-        return freqs
-
-    f_max = f_max or sample_rate / 2
-
-    # generate frequency points
-
-    n_freqs = n_fft // 2 + 1
-    all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)
-
-    # convert frequencies to mel and back to hz
-
-    m_min = hz_to_mel(f_min, mel_scale)
-    m_max = hz_to_mel(f_max, mel_scale)
-    m_pts = mx.linspace(m_min, m_max, n_mels + 2)
-    f_pts = mel_to_hz(m_pts, mel_scale)
-
-    # compute slopes for filterbank
-
-    f_diff = f_pts[1:] - f_pts[:-1]
-    slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)
-
-    # calculate overlapping triangular filters
-
-    down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
-    up_slopes = slopes[:, 2:] / f_diff[1:]
-    filterbank = mx.maximum(
-        mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
-    )
-
-    if norm == "slaney":
-        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
-        filterbank *= mx.expand_dims(enorm, 0)
-
-    filterbank = filterbank.moveaxis(0, 1)
-    return filterbank
-
-
-@lru_cache(maxsize=None)
-def hanning(size):
-    return mx.array(
-        [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)]
-    )
-
-
-def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="reflect"):
-    if nfft is None:
-        nfft = nperseg
-    if noverlap is None:
-        noverlap = nfft // 4
-
-    def _pad(x, padding, pad_mode="constant"):
-        if pad_mode == "constant":
-            return mx.pad(x, [(padding, padding)])
-        elif pad_mode == "reflect":
-            prefix = x[1 : padding + 1][::-1]
-            suffix = x[-(padding + 1) : -1][::-1]
-            return mx.concatenate([prefix, x, suffix])
-        else:
-            raise ValueError(f"Invalid pad_mode {pad_mode}")
-
-    if window.shape[0] < nfft:
-        pad_left = (nfft - window.shape[0]) // 2
-        pad_right = nfft - window.shape[0] - pad_left
-        window = mx.pad(window, (pad_left, pad_right))
-
-    padding = nfft // 2
-    x = _pad(x, padding, pad_mode)
-
-    strides = [noverlap, 1]
-    t = (x.size - nperseg + noverlap) // noverlap
-    shape = [t, nfft]
-    x = mx.as_strided(x, shape=shape, strides=strides)
-    return mx.fft.rfft(x * window)
-
-
-def istft(x, window, nperseg=256, noverlap=None, nfft=None):
-    if nfft is None:
-        nfft = nperseg
-    if noverlap is None:
-        noverlap = nfft // 4
-
-    t = (x.shape[0] - 1) * noverlap + nperseg
-    reconstructed = mx.zeros(t)
-    window_sum = mx.zeros(t)
-
-    for i in range(x.shape[0]):
-        # inverse FFT of each frame
-        frame_time = mx.fft.irfft(x[i])
-
-        # get the position in the time-domain signal to add the frame
-        start = i * noverlap
-        end = start + nperseg
-
-        # overlap-add the inverse transformed frame, scaled by the window
-        reconstructed[start:end] += frame_time * window
-        window_sum[start:end] += window
-
-    # normalize by the sum of the window values
-    reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)
-
-    return reconstructed
+from mlx_audio.utils import hanning, mel_filters, stft
 
 
 def log_mel_spectrogram(
@@ -163,7 +19,7 @@ def log_mel_spectrogram(
     if padding > 0:
         audio = mx.pad(audio, (0, padding))
 
-    freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length)
+    freqs = stft(audio, window=hanning(n_fft), n_fft=n_fft, win_length=hop_length)
     magnitudes = freqs[:-1, :].abs()
     filters = mel_filters(
         sample_rate=sample_rate,
```
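The HTK mel-scale math that the removed `mel_filters` implements (now consolidated in `mlx_audio.utils`) is a simple log mapping between frequency and mel. A dependency-free sketch of the forward and inverse conversions, lifted from the formulas in the deleted code:

```python
import math


def hz_to_mel(freq: float) -> float:
    # HTK mel scale: mel = 2595 * log10(1 + f / 700)
    return 2595.0 * math.log10(1.0 + freq / 700.0)


def mel_to_hz(mel: float) -> float:
    # Exact inverse of the HTK mapping above
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


# The scale is anchored so 1 kHz lands near 1000 mel, and the
# mapping round-trips cleanly within floating-point precision.
m = hz_to_mel(1000.0)
assert abs(mel_to_hz(m) - 1000.0) < 1e-6
```

The filterbank itself then just places `n_mels + 2` points evenly in mel space, maps them back to Hz, and builds overlapping triangular filters between adjacent points, as in the removed code above.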

mlx_audio/codec/models/vocos/vocos.py

Lines changed: 7 additions & 6 deletions

```diff
@@ -9,8 +9,10 @@
 import yaml
 from huggingface_hub import snapshot_download
 
+from mlx_audio.utils import hanning, istft
+
 from ..encodec import Encodec
-from .mel import hanning, istft, log_mel_spectrogram
+from .mel import log_mel_spectrogram
 
 
 class FeatureExtractor(nn.Module):
@@ -130,11 +132,10 @@ def __call__(self, x: mx.array) -> mx.array:
         y = mx.sin(p)
         S = mag * (x + 1j * y)
         audio = istft(
-            S.squeeze(0).swapaxes(0, 1),
-            hanning(self.n_fft),
-            self.n_fft,
-            self.hop_length,
-            self.n_fft,
+            S.squeeze(0),
+            window=hanning(self.n_fft),
+            hop_length=self.hop_length,
+            win_length=self.n_fft,
         )
         return audio
```
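The `istft` call above is a standard overlap-add reconstruction: each inverse-transformed frame is added into the output buffer at an offset of one hop per frame. A minimal pure-Python sketch of the overlap-add step (not MLX; window scaling and normalization omitted for clarity):

```python
def overlap_add(frames: list, hop: int) -> list:
    # Output length follows the removed istft's formula:
    # (n_frames - 1) * hop + frame_len
    frame_len = len(frames[0])
    out = [0.0] * ((len(frames) - 1) * hop + frame_len)
    for i, frame in enumerate(frames):
        for j, sample in enumerate(frame):
            out[i * hop + j] += sample  # accumulate overlapping regions
    return out


# Two length-4 frames with hop 2: the middle two samples overlap and sum.
print(overlap_add([[1, 1, 1, 1], [1, 1, 1, 1]], hop=2))
# -> [1.0, 1.0, 2.0, 2.0, 1.0, 1.0]
```

In the real implementation, each frame is multiplied by the analysis window before accumulation and the result is divided by the accumulated window sum, which is why the doubled middle samples do not distort the signal.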

mlx_audio/codec/tests/test_vocos.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -65,12 +65,12 @@ def test_vocos_24khz(self):
 
         # reconstruct from mel spec
         reconstructed_audio = model(audio)
-        self.assertEqual(reconstructed_audio.shape, (120576,))
+        self.assertEqual(reconstructed_audio.shape, (119552,))
 
         # decode from mel spec
         mel_spec = log_mel_spectrogram(audio)
         decoded = model.decode(mel_spec)
-        self.assertEqual(decoded.shape, (120576,))
+        self.assertEqual(decoded.shape, (119552,))
 
         model = Vocos.from_hparams(config_encodec)
 
@@ -79,14 +79,14 @@ def test_vocos_24khz(self):
         reconstructed_audio = model(
             audio, bandwidth_id=mx.array(bandwidth_id)[None, ...]
         )
-        self.assertEqual(reconstructed_audio.shape, (120960,))
+        self.assertEqual(reconstructed_audio.shape, (119680,))
 
         # decode with encodec codes
         codes = model.get_encodec_codes(audio, bandwidth_id=bandwidth_id)
         decoded = model.decode_from_codes(
             codes, bandwidth_id=mx.array(bandwidth_id)[None, ...]
        )
-        self.assertEqual(decoded.shape, (120960,))
+        self.assertEqual(decoded.shape, (119680,))
 
 
 if __name__ == "__main__":
```
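The updated shape expectations follow from framing arithmetic: the removed `istft` computed its output length as `t = (x.shape[0] - 1) * noverlap + nperseg`, so any change in hop or window-length conventions shifts the reconstructed length. A small sketch of that arithmetic (the parameter values below are illustrative only, not the exact settings Vocos uses):

```python
def istft_length(n_frames: int, hop: int, frame_len: int) -> int:
    # Same formula as the removed istft:
    # t = (x.shape[0] - 1) * noverlap + nperseg
    return (n_frames - 1) * hop + frame_len


# One frame reconstructs to exactly one frame length;
# each additional frame contributes one hop of new samples.
print(istft_length(1, 256, 1024))    # -> 1024
print(istft_length(100, 256, 1024))  # -> 26368
```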
