 import transformers
 from einops import rearrange
 from packaging import version
-from torch import nn
 
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.models.layers.norm import (NORM_CLASS_REGISTRY, LPLayerNorm,
+                                           low_precision_groupnorm)
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -498,6 +498,47 @@ def triton_flash_attn_fn(
     return output, None, past_key_value
 
 
+def _expand_params(heads: int, param: Optional[torch.Tensor] = None):
+    if param is None:
+        return None
+    return param.repeat(heads)
+
+
+def _apply_qk_gn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    n_heads: int,
+    kv_n_heads: int,
+    q_ln: nn.Module,
+    k_ln: nn.Module,
+):
+    dtype = query.dtype
+
+    w = _expand_params(n_heads, q_ln.weight)
+    b = _expand_params(n_heads, q_ln.bias)
+    if isinstance(q_ln, LPLayerNorm):
+        query = low_precision_groupnorm(query, n_heads, w, b,
+                                        eps=q_ln.eps).to(dtype)
+    elif isinstance(q_ln, nn.LayerNorm):
+        query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
+
+    w = _expand_params(kv_n_heads, k_ln.weight)
+    b = _expand_params(kv_n_heads, k_ln.bias)
+    if isinstance(k_ln, LPLayerNorm):
+        key = low_precision_groupnorm(key, kv_n_heads, w, b,
+                                      eps=k_ln.eps).to(dtype)
+    elif isinstance(k_ln, nn.LayerNorm):
+        key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given k_ln type ({type(k_ln)=}).')
+
+    return query, key
+
+
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
 
@@ -629,16 +670,22 @@ def forward(
 
         key_padding_mask = attention_mask
 
-        if self.qk_ln or self.qk_gn:
+        if self.qk_gn:
+            # Applying groupnorm to qk
+            query, key = _apply_qk_gn(
+                query,
+                key,
+                self.n_heads,
+                self.kv_n_heads,
+                self.q_ln,
+                self.k_ln,
+            )
+
+        if self.qk_ln:
             # Applying layernorm to qk
-            q_shape, k_shape = query.shape, key.shape
-            if self.qk_gn:
-                b, s = query.shape[:2]
-                query = query.view(b, s, self.n_heads, -1)
-                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype).view(q_shape)
-            key = self.k_ln(key).to(dtype).view(k_shape)
+            query = self.q_ln(query).to(dtype)
+            key = self.k_ln(key).to(dtype)
 
         if rotary_emb_w_meta_info is not None:
             rotary_emb = rotary_emb_w_meta_info['rotary_emb']