
Commit bca1c33

Author: root
Commit message: updt to include low precision groupnorm;
Parent: 6c0472b

2 files changed: 80 additions, 10 deletions

llmfoundry/models/layers/attention.py

Lines changed: 57 additions & 10 deletions
@@ -12,10 +12,10 @@
 import transformers
 from einops import rearrange
 from packaging import version
-from torch import nn

 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.models.layers.norm import (NORM_CLASS_REGISTRY, LPLayerNorm,
+                                           low_precision_groupnorm)


 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -498,6 +498,47 @@ def triton_flash_attn_fn(
     return output, None, past_key_value


+def _expand_params(heads: int, param: Optional[torch.Tensor] = None):
+    if param is None:
+        return None
+    return param.repeat(heads)
+
+
+def _apply_qk_gn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    n_heads: int,
+    kv_n_heads: int,
+    q_ln: nn.Module,
+    k_ln: nn.Module,
+):
+    dtype = query.dtype
+
+    w = _expand_params(n_heads, q_ln.weight)
+    b = _expand_params(n_heads, q_ln.bias)
+    if isinstance(q_ln, LPLayerNorm):
+        query = low_precision_groupnorm(query, n_heads, w, b,
+                                        eps=q_ln.eps).to(dtype)
+    elif isinstance(q_ln, nn.LayerNorm):
+        query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
+
+    w = _expand_params(kv_n_heads, k_ln.weight)
+    b = _expand_params(kv_n_heads, k_ln.bias)
+    if isinstance(k_ln, LPLayerNorm):
+        key = low_precision_groupnorm(key, kv_n_heads, w, b,
+                                      eps=k_ln.eps).to(dtype)
+    elif isinstance(k_ln, nn.LayerNorm):
+        key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given k_ln type ({type(k_ln)=}).')
+
+    return query, key
+
+
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
@@ -629,16 +670,22 @@ def forward(

         key_padding_mask = attention_mask

-        if self.qk_ln or self.qk_gn:
+        if self.qk_gn:
+            # Applying groupnorm to qk
+            query, key = _apply_qk_gn(
+                query,
+                key,
+                self.n_heads,
+                self.kv_n_heads,
+                self.q_ln,
+                self.k_ln,
+            )
+
+        if self.qk_ln:
             # Applying layernorm to qk
-            q_shape, k_shape = query.shape, key.shape
-            if self.qk_gn:
-                b, s = query.shape[:2]
-                query = query.view(b, s, self.n_heads, -1)
-                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype).view(q_shape)
-            key = self.k_ln(key).to(dtype).view(k_shape)
+            query = self.q_ln(query).to(dtype)
+            key = self.k_ln(key).to(dtype)

         if rotary_emb_w_meta_info is not None:
             rotary_emb = rotary_emb_w_meta_info['rotary_emb']
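For intuition, here is a minimal, self-contained check (not from the repo) of the idea behind _expand_params and _apply_qk_gn: per-head LayerNorm, i.e. the old reshape-to-(b, s, n_heads, head_dim) path that this hunk removes, matches F.group_norm with num_groups=n_heads once the (head_dim,) affine parameters are repeated across heads. The shapes and the (b*s, d_model) flattening below are assumptions made only to keep the check self-contained, not a quote of how the module lays out its tensors.

# Illustrative check only: per-head LayerNorm == group_norm with one group
# per head and the (head_dim,) affine params repeated across heads.
import torch
import torch.nn.functional as F
from torch import nn

b, s, n_heads, head_dim = 2, 4, 8, 16
d_model = n_heads * head_dim
x = torch.randn(b, s, d_model)

ln = nn.LayerNorm(head_dim)            # per-head norm, as configured for qk_gn
nn.init.normal_(ln.weight)             # non-trivial affine params for the check
nn.init.normal_(ln.bias)

# Old path: view as (b, s, n_heads, head_dim) and layernorm the last dim per head.
ref = ln(x.view(b, s, n_heads, head_dim)).view(b, s, d_model)

# New path: repeat the per-head params to (d_model,) (what _expand_params does)
# and group-normalize with one group per head. group_norm expects (N, C, ...),
# so the tokens are folded into the batch dim here for the comparison.
w = ln.weight.repeat(n_heads)
bias = ln.bias.repeat(n_heads)
out = F.group_norm(x.reshape(b * s, d_model), n_heads, w, bias,
                   eps=ln.eps).view(b, s, d_model)

print(torch.allclose(ref, out, atol=1e-5))  # True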

llmfoundry/models/layers/norm.py

Lines changed: 23 additions & 0 deletions
@@ -53,6 +53,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        )


+def low_precision_groupnorm(
+    x: torch.Tensor,
+    groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-05,
+):
+    device = x.device
+    downcast_x = _cast_if_autocast_enabled(x)
+    downcast_weight = _cast_if_autocast_enabled(
+        weight) if weight is not None else weight
+    downcast_bias = _cast_if_autocast_enabled(
+        bias) if bias is not None else bias
+    with torch.autocast(enabled=False, device_type=device.type):
+        return torch.nn.functional.group_norm(
+            downcast_x,
+            groups,
+            downcast_weight,
+            downcast_bias,
+            eps,
+        )
+
+
 def rms_norm(x: torch.Tensor,
              weight: Optional[torch.Tensor] = None,
              eps: float = 1e-5) -> torch.Tensor:
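For orientation, a short usage sketch of the new helper (not from the repo; the CUDA device, bf16 autocast context, and tensor shapes are assumptions for illustration, and it presumes this commit is installed so the import resolves). Inside an autocast region, low_precision_groupnorm downcasts x, weight, and bias to the autocast dtype via _cast_if_autocast_enabled and then calls group_norm with autocast disabled, so the normalization runs in the low-precision dtype instead of being promoted back to fp32, mirroring what LPLayerNorm does for layer_norm.

# Sketch only (assumes a CUDA device with bf16 autocast; shapes are arbitrary).
# Under autocast, plain F.group_norm is executed by autocast in fp32, while
# low_precision_groupnorm downcasts its inputs and disables autocast, so the
# kernel runs in the low-precision dtype.
import torch

from llmfoundry.models.layers.norm import low_precision_groupnorm

n_heads, head_dim = 8, 16
d_model = n_heads * head_dim
x = torch.randn(4, d_model, device='cuda')   # (tokens, d_model); one group per head
weight = torch.ones(d_model, device='cuda')
bias = torch.zeros(d_model, device='cuda')

with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    ref = torch.nn.functional.group_norm(x, n_heads, weight, bias, eps=1e-5)
    out = low_precision_groupnorm(x, n_heads, weight, bias, eps=1e-5)

print(ref.dtype, out.dtype)  # torch.float32 torch.bfloat16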
