Skip to content

Commit 35c3e17

Browse files
committed
update ultravox model to support v0.5 release
Signed-off-by: Farzad Abdolhosseini <[email protected]>
1 parent eaa92d4 commit 35c3e17

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

vllm/model_executor/models/ultravox.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from vllm import envs
1919
from vllm.attention import AttentionMetadata
2020
from vllm.config import VllmConfig
21-
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
21+
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
2222
from vllm.model_executor.layers.layernorm import RMSNorm
2323
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
2424
from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
252252
return audio_embeds
253253

254254

255+
class FlippedSiluAndMul(SiluAndMul):
    """SiluAndMul variant that swaps the two halves of its input first.

    Ultravox is trained with SwiGLU with flipped halves, so the input is
    split in two along the last dimension, the halves are exchanged, and
    the result is handed to the parent ``SiluAndMul.forward``.
    """

    def forward(self, x: torch.Tensor):
        # Split along the last dimension, then re-join in reverse order.
        first_half, second_half = torch.chunk(x, 2, dim=-1)
        return super().forward(torch.cat([second_half, first_half], dim=-1))
262+
263+
255264
class UltravoxProjector(nn.Module):
# Maps stacked audio features to vectors of size
# config.text_config.hidden_size using: pad/stack -> ln_pre -> linear_1 ->
# activation -> ln_mid -> linear_2 -> ln_post (see forward below).
# NOTE(review): this span is a rendered diff; lines following an `NNN-`
# gutter appear to be the removed pre-commit code and lines following an
# `NNN+` gutter the added code -- confirm against the repository.
256265

257266
def __init__(self, config: UltravoxConfig):
258267
super().__init__()
259268
self.hidden_dim = config.hidden_size
260269
self._pad_and_stack = StackAudioFrames(config.stack_factor)
261-
dim = config.audio_config.hidden_size * config.stack_factor
262-
self.ln_pre = RMSNorm(dim)
263-
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
264-
dim = self.hidden_dim
# Input width of the projector: audio hidden size times the stack factor
# (presumably because stack_factor frames are concatenated by
# StackAudioFrames -- TODO confirm against that helper).
270+
dim_in = config.audio_config.hidden_size * config.stack_factor
271+
self.ln_pre = RMSNorm(dim_in)
272+
self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
273+
dim_mid = self.hidden_dim
265274

266275
if config.projector_act == "swiglu":
267-
self.act = MulAndSilu()
268-
dim = dim // 2
# With "swiglu", linear_2 is sized for half of hidden_dim, i.e. the
# activation is expected to halve the feature dimension.
276+
self.act = FlippedSiluAndMul()
277+
dim_mid = dim_mid // 2
269278
else:
270279
self.act = get_act_fn(config.projector_act)
271280

272-
self.linear_2 = nn.Linear(dim,
273-
config.text_config.hidden_size,
274-
bias=False)
275-
self.ln_post = RMSNorm(config.text_config.hidden_size)
281+
dim_out = config.text_config.hidden_size
282+
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
283+
284+
# Ultravox v0.4.1 and below apply the norm after the second linear layer
285+
# (ln_post), while v0.5.0 and above apply it after the activation, before
# the second linear layer (ln_mid). Exactly one of ln_mid / ln_post is an
# RMSNorm; the other is nn.Identity, so forward can always call both.
286+
if config.projector_ln_mid:
287+
self.ln_mid: nn.Module = RMSNorm(dim_mid)
288+
self.ln_post = nn.Identity()
289+
else:
290+
self.ln_mid = nn.Identity()
291+
self.ln_post = RMSNorm(dim_out)
276292

277293
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
# Applies the projector stages in order and returns the output of ln_post.
278294
audio_features = self._pad_and_stack(audio_features)
279295
audio_features = self.ln_pre(audio_features)
280296
hidden_states = self.linear_1(audio_features)
281297
hidden_states = self.act(hidden_states)
298+
hidden_states = self.ln_mid(hidden_states)
282299
hidden_states = self.linear_2(hidden_states)
283300
hidden_states = self.ln_post(hidden_states)
284301
return hidden_states

vllm/transformers_utils/configs/ultravox.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
3737
The LoRA configuration for finetuning the text model.
3838
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
3939
The LoRA configuration for finetuning the audio model.
40+
projector_ln_mid (`bool`, *optional*, defaults to `False`):
41+
Whether to apply layer normalization at the middle of the
42+
projector or at the end. Versions v0.4.1 and below
43+
use `False`, but v0.5 and above use `True`.
44+
audio_latency_block_size (`int`, *optional*, defaults to `None`):
45+
The latency block size for simulating audio streaming.
4046
"""
4147

4248
model_type = "ultravox"
@@ -56,6 +62,8 @@ def __init__(
5662
projector_act: str = "swiglu",
5763
text_model_lora_config: Optional[Dict[str, Any]] = None,
5864
audio_model_lora_config: Optional[Dict[str, Any]] = None,
65+
projector_ln_mid: bool = False,
66+
audio_latency_block_size: Optional[int] = None,
5967
**kwargs,
6068
):
6169
self.ignore_index = ignore_index
@@ -68,6 +76,8 @@ def __init__(
6876
self.stack_factor = stack_factor
6977
self.norm_init = norm_init
7078
self.projector_act = projector_act
79+
self.projector_ln_mid = projector_ln_mid
80+
self.audio_latency_block_size = audio_latency_block_size
7181

7282
if text_model_id is not None:
7383
# Avoid circular import

0 commit comments

Comments
 (0)