|
18 | 18 | from vllm import envs
|
19 | 19 | from vllm.attention import AttentionMetadata
|
20 | 20 | from vllm.config import VllmConfig
|
21 |
| -from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn |
| 21 | +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn |
22 | 22 | from vllm.model_executor.layers.layernorm import RMSNorm
|
23 | 23 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
24 | 24 | from vllm.model_executor.model_loader.loader import DefaultModelLoader
|
@@ -252,33 +252,50 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
|
252 | 252 | return audio_embeds
|
253 | 253 |
|
254 | 254 |
|
class FlippedSiluAndMul(SiluAndMul):
    """SwiGLU activation whose gate/value halves are swapped.

    Ultravox checkpoints were trained with SwiGLU taking its two halves in
    the opposite order from vLLM's ``SiluAndMul``, so reorder the last
    dimension before delegating to the parent implementation.
    """

    def forward(self, x: torch.Tensor):
        # Split the last dim in two, reassemble in reverse order, and run
        # the standard SiluAndMul on the swapped tensor.
        first_half, second_half = x.chunk(2, dim=-1)
        swapped = torch.cat((second_half, first_half), dim=-1)
        return super().forward(swapped)
| 263 | + |
class UltravoxProjector(nn.Module):
    """Projects stacked audio-encoder frames into the text model's hidden space.

    Pipeline: pad/stack frames -> pre-RMSNorm -> linear_1 -> activation
    -> (optional mid-RMSNorm) -> linear_2 -> (optional post-RMSNorm).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        # Stacking concatenates `stack_factor` encoder frames along the
        # feature dimension, so the projector input width scales with it.
        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        if config.projector_act == "swiglu":
            # SwiGLU consumes the two halves of the feature dim, halving it.
            self.act = FlippedSiluAndMul()
            mid_dim = self.hidden_dim // 2
        else:
            self.act = get_act_fn(config.projector_act)
            mid_dim = self.hidden_dim

        text_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_dim, text_dim, bias=False)

        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(mid_dim)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(text_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Map raw encoder features to text-model-sized embeddings."""
        x = self._pad_and_stack(audio_features)
        x = self.ln_pre(x)
        x = self.act(self.linear_1(x))
        x = self.linear_2(self.ln_mid(x))
        return self.ln_post(x)
|
|
0 commit comments