
Commit 2b62d5e

perf improvement

Author: root
1 parent: bca1c33

File tree: 2 files changed (+29, -10 lines)

llmfoundry/models/layers/attention.py

Lines changed: 12 additions & 6 deletions
@@ -514,23 +514,29 @@ def _apply_qk_gn(
 ):
     dtype = query.dtype
 
-    w = _expand_params(n_heads, q_ln.weight)
-    b = _expand_params(n_heads, q_ln.bias)
     if isinstance(q_ln, LPLayerNorm):
-        query = low_precision_groupnorm(query, n_heads, w, b,
+        query = low_precision_groupnorm(query,
+                                        n_heads,
+                                        q_ln.weight,
+                                        q_ln.bias,
                                         eps=q_ln.eps).to(dtype)
     elif isinstance(q_ln, nn.LayerNorm):
+        w = _expand_params(n_heads, q_ln.weight)
+        b = _expand_params(n_heads, q_ln.bias)
         query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
     else:
         raise ValueError(
             f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
 
-    w = _expand_params(kv_n_heads, k_ln.weight)
-    b = _expand_params(kv_n_heads, k_ln.bias)
     if isinstance(k_ln, LPLayerNorm):
-        key = low_precision_groupnorm(key, kv_n_heads, w, b,
+        key = low_precision_groupnorm(key,
+                                      kv_n_heads,
+                                      k_ln.weight,
+                                      k_ln.bias,
                                       eps=k_ln.eps).to(dtype)
     elif isinstance(k_ln, nn.LayerNorm):
+        w = _expand_params(kv_n_heads, k_ln.weight)
+        b = _expand_params(kv_n_heads, k_ln.bias)
         key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
     else:
         raise ValueError(
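
In short, the parameter expansion is no longer done up front for both branches: the LPLayerNorm path now passes the q_ln / k_ln weights and biases through unchanged and lets low_precision_groupnorm expand them after downcasting (see norm.py below), while the nn.LayerNorm path only pays for _expand_params when that branch is actually taken. As a minimal, self-contained illustration (not the repo's code) of why the expansion is needed at all: nn.functional.group_norm expects one affine element per channel, so a parameter stored per group has to be tiled up to the channel dimension first. The per-group parameter layout and shapes below are assumptions for the example only.

import torch
from torch.nn import functional as F

n_heads, head_dim, batch = 4, 8, 2
d_model = n_heads * head_dim

x = torch.randn(batch, d_model)      # (batch, n_heads * head_dim), heads flattened
per_group_w = torch.ones(n_heads)    # hypothetical per-group scale (constant, so tiling order is moot)
per_group_b = torch.zeros(n_heads)   # hypothetical per-group shift

# group_norm affine params need one element per channel, so tile the
# per-group params up to d_model, mirroring the .repeat in the new helper
w = per_group_w.repeat(d_model // n_heads)
b = per_group_b.repeat(d_model // n_heads)

out = F.group_norm(x, n_heads, w, b, eps=1e-5)
print(out.shape)  # torch.Size([2, 32])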

llmfoundry/models/layers/norm.py

Lines changed: 17 additions & 4 deletions
@@ -53,6 +53,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
 
+def _expand_params(x: torch.Tensor, param: Optional[torch.Tensor] = None):
+    # repeat param if params are applied per group
+    if param is None:
+        return None
+    if x.shape[-1] == param.shape[-1]:
+        return param
+    return param.repeat(x.shape[-1] // param.shape[-1])
+
+
 def low_precision_groupnorm(
     x: torch.Tensor,
     groups: int,
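
A quick usage sketch of the helper above (copied standalone for illustration): it returns the parameter untouched when it already has one element per channel, tiles it when it is stored per group, and tolerates None for norms configured without affine parameters.

import torch
from typing import Optional

def _expand_params(x: torch.Tensor, param: Optional[torch.Tensor] = None):
    # repeat param if params are applied per group
    if param is None:
        return None
    if x.shape[-1] == param.shape[-1]:
        return param
    return param.repeat(x.shape[-1] // param.shape[-1])

x = torch.randn(2, 32)

full = torch.randn(32)
assert _expand_params(x, full) is full              # already per-channel: returned as-is, no copy
assert _expand_params(x, None) is None              # no affine params configured

per_group = torch.randn(4)
assert _expand_params(x, per_group).shape == (32,)  # tiled 8x to match the channel dim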
@@ -62,10 +71,14 @@ def low_precision_groupnorm(
 ):
     device = x.device
     downcast_x = _cast_if_autocast_enabled(x)
-    downcast_weight = _cast_if_autocast_enabled(
-        weight) if weight is not None else weight
-    downcast_bias = _cast_if_autocast_enabled(
-        bias) if bias is not None else bias
+    downcast_weight, downcast_bias = None, None
+    if weight is not None:
+        downcast_weight = _cast_if_autocast_enabled(weight)
+        downcast_weight = _expand_params(x, downcast_weight)
+    if bias is not None:
+        downcast_bias = _cast_if_autocast_enabled(bias)
+        downcast_bias = _expand_params(x, downcast_bias)
+
     with torch.autocast(enabled=False, device_type=device.type):
         return torch.nn.functional.group_norm(
             downcast_x,
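
The reordering is where the perf win appears to come from: the parameters now reach low_precision_groupnorm in their original (possibly per-group) size, get downcast first, and are only then expanded, and the expansion is skipped entirely when they already match the channel count. A toy sketch of the two orderings follows (the bfloat16 cast stands in for whatever dtype _cast_if_autocast_enabled would pick; none of this is the repo's code):

import torch

def expand(x, param):
    # same pass-through / repeat logic as the new _expand_params
    if param is None or x.shape[-1] == param.shape[-1]:
        return param
    return param.repeat(x.shape[-1] // param.shape[-1])

x = torch.randn(2, 32)
weight = torch.randn(4)   # hypothetical per-group parameter, fp32

# previous order: expand in fp32, then downcast the already-expanded tensor
old = expand(x, weight).to(torch.bfloat16)    # fp32 repeat + bf16 copy, both length 32

# new order: downcast the small tensor, then expand the low-precision copy
new = expand(x, weight.to(torch.bfloat16))    # bf16 copy of length 4, then repeat to 32

assert torch.equal(old, new)                  # same values, fewer full-size fp32 temporaries

When the weight already spans the full channel dimension (the common LayerNorm(d_model) case), the expand step returns it unchanged and only the downcast copy is made.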
