Commit 6c0472b

Merge branch 'main' into qkgn
2 parents 3610525 + 19368c6 commit 6c0472b

15 files changed: +503 -300 lines changed

.ci/FILE_HEADER

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-Copyright 2022 MosaicML LLM Foundry authors
+Copyright 2024 MosaicML LLM Foundry authors
 SPDX-License-Identifier: Apache-2.0

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
@@ -57,14 +57,15 @@ repos:
   - id: mixed-line-ending
   - id: trailing-whitespace
 - repo: https://github.com/Lucas-C/pre-commit-hooks
-  rev: v1.3.1
+  rev: v1.5.4
   hooks:
   - id: insert-license
     args:
     - --license-filepath
     - .ci/FILE_HEADER
     - --comment-style
     - '#'
+    - --allow-past-years
     types: [python]
 - repo: https://github.com/PyCQA/docformatter
   rev: v1.5.0

llmfoundry/models/layers/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from llmfoundry.models.layers.attention import (
-    ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
-    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
-    scaled_multihead_dot_product_attention, triton_flash_attn_fn)
+    ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
+    MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
+    flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
@@ -17,6 +17,7 @@
     'triton_flash_attn_fn',
     'MultiheadAttention',
     'MultiQueryAttention',
+    'GroupedQueryAttention',
     'attn_bias_shape',
     'build_attn_bias',
     'build_alibi_bias',
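
This change re-exports the new GroupedQueryAttention class at the package level alongside the existing attention classes. A minimal usage sketch, assuming the constructor accepts d_model, n_heads, kv_n_heads, and attn_impl keyword arguments (the kwargs are not shown in this diff); grouped-query attention keeps n_heads query heads but shares kv_n_heads key/value heads, sitting between multi-head (kv_n_heads == n_heads) and multi-query (kv_n_heads == 1) attention and shrinking the KV cache accordingly.

# Sketch only: exercises the new package-level export added in this file.
# The constructor kwargs below are assumptions, not confirmed by the diff.
from llmfoundry.models.layers import GroupedQueryAttention

gqa = GroupedQueryAttention(
    d_model=256,
    n_heads=8,
    kv_n_heads=2,       # assumed kwarg: number of shared key/value heads
    attn_impl='torch',  # assumed kwarg: avoids requiring the flash/triton kernels
)
print(gqa)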

llmfoundry/models/layers/attention.py

Lines changed: 22 additions & 17 deletions
@@ -228,12 +228,17 @@ def flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
     should_repeat_kv_for_gqa: Optional[bool] = True,
     sliding_window_size: int = -1,
     alibi_slopes: Optional[torch.Tensor] = None,
+    flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
+    if key_padding_mask is not None:
+        raise ValueError('key_padding_mask should be None for flash attn.')
+    del key_padding_mask
+    if flash_attn_padding_info is None:
+        raise ValueError('flash_attn_padding_info is required for flash attn.')
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
@@ -267,25 +272,24 @@ def flash_attn_fn(

     batch_size, seqlen = query.shape[:2]

-    if attention_mask_in_length is None:
-        if key_padding_mask is None:
-            key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
-        query_padding_mask = key_padding_mask[:, -query.size(1):]
-        unpadding_function = bert_padding.unpad_input
-    else:
-        key_padding_mask = attention_mask_in_length
-        query_padding_mask = attention_mask_in_length
-        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+    indices_q = flash_attn_padding_info['indices_q']
+    indices_k = flash_attn_padding_info['indices_k']
+    indices_v = flash_attn_padding_info['indices_v']
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
+    max_seqlen_k = flash_attn_padding_info['max_seqlen_k']

-    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        query, query_padding_mask)
+    query_unpad = bert_padding.index_first_axis(
+        rearrange(query, 'b s ... -> (b s) ...'), indices_q)
     query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

-    key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        key, key_padding_mask)
+    key_unpad = bert_padding.index_first_axis(
+        rearrange(key, 'b s ... -> (b s) ...'), indices_k)
     key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

-    value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
+    value_unpad = bert_padding.index_first_axis(
+        rearrange(value, 'b s ... -> (b s) ...'), indices_v)
     value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

     if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
@@ -605,8 +609,8 @@ def forward(
         rotary_emb_w_meta_info: Optional[dict] = None,
         is_causal: bool = True,
         needs_weights: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
             torch.Tensor, torch.Tensor]]]:
         qkv = self.Wqkv(x)
@@ -677,11 +681,12 @@ def forward(

         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
+            key_padding_mask = None
             extra_attn_kwargs = {
-                'attention_mask_in_length': attention_mask_in_length,
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
                 'alibi_slopes': alibi_slopes,
+                'flash_attn_padding_info': flash_attn_padding_info,
             }

         context, attn_weights, past_key_value = self.attn_fn(
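
With this change, flash_attn_fn no longer derives unpadding metadata from key_padding_mask or attention_mask_in_length; callers must pass a precomputed flash_attn_padding_info dict containing the seven keys read above. A minimal sketch of that dict for a fully dense batch (no padding), assuming the usual flash-attn conventions (int32 cumulative sequence lengths, flat token indices); the concrete values are illustrative and not part of the diff.

# Sketch only: the dict keys are exactly those consumed by flash_attn_fn in
# this diff; the dense-batch (no padding) values are illustrative assumptions.
import torch

bsz, seqlen = 2, 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'

indices = torch.arange(bsz * seqlen, device=device)  # keep every token
cu_seqlens = torch.arange(0, (bsz + 1) * seqlen, seqlen,
                          dtype=torch.int32, device=device)

flash_attn_padding_info = {
    'indices_q': indices,
    'indices_k': indices,
    'indices_v': indices,
    'cu_seqlens_q': cu_seqlens,
    'cu_seqlens_k': cu_seqlens,
    'max_seqlen_q': seqlen,
    'max_seqlen_k': seqlen,
}

In this commit the dict is built by the gen_flash_attn_padding_info helper added to llmfoundry/models/mpt/modeling_mpt.py below.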

llmfoundry/models/layers/blocks.py

Lines changed: 2 additions & 2 deletions
@@ -123,8 +123,8 @@ def forward(
         attention_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
         output_attentions: bool = False,
-        attention_mask_in_length: Optional[torch.Tensor] = None,
         alibi_slopes: Optional[torch.Tensor] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
             torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
@@ -136,8 +136,8 @@ def forward(
             attention_mask=attention_mask,
             is_causal=is_causal,
             needs_weights=output_attentions,
-            attention_mask_in_length=attention_mask_in_length,
             alibi_slopes=alibi_slopes,
+            flash_attn_padding_info=flash_attn_padding_info,
         )
         x = x + self.resid_attn_dropout(b)
         m = x

llmfoundry/models/mpt/modeling_mpt.py

Lines changed: 84 additions & 19 deletions
@@ -6,6 +6,8 @@
 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """

+from __future__ import annotations
+
 import math
 import warnings
 from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
@@ -24,15 +26,23 @@
 from composer.models import HuggingFaceModel
 from composer.utils import dist

-from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+                                                is_flash_v2_installed)

 if is_flash_v2_installed():
     try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
         from flash_attn.layers.rotary import \
             RotaryEmbedding as DAILRotaryEmbedding
     except Exception as e:
         raise e

+if is_flash_v1_installed():
+    try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
+    except Exception as e:
+        raise e
+
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -216,6 +226,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     return attention_mask_in_length


+def gen_flash_attn_padding_info(
+        bsz: int,
+        S: int,
+        past_key_len: int,
+        device: torch.device,
+        attention_mask_in_length: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None):
+    flash_attn_padding_info = {}
+    if attention_mask_in_length is None:
+        key_padding_mask = attention_mask
+        if key_padding_mask is None:
+            key_padding_mask = torch.ones((bsz, past_key_len + S),
+                                          dtype=torch.bool,
+                                          device=device)
+        query_padding_mask = key_padding_mask[:, -S:]
+        unpadding_function = bert_padding.unpad_input
+    else:
+        key_padding_mask = attention_mask_in_length
+        query_padding_mask = attention_mask_in_length
+        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+
+    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+        torch.empty(bsz, S, 1, device=device), query_padding_mask)
+    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+        torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    _, indices_v, _, _ = unpadding_function(
+        torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+
+    flash_attn_padding_info['indices_q'] = indices_q
+    flash_attn_padding_info['indices_k'] = indices_k
+    flash_attn_padding_info['indices_v'] = indices_v
+    flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
+    flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
+    flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
+    flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
+    return flash_attn_padding_info
+
+
 def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
                       max_seq_len: int) -> torch.Tensor:
     seq_len = sequence_id.shape[-1]
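
The helper above centralizes the unpadding bookkeeping that flash_attn_fn used to redo in every layer; MPTModel.forward now calls it once per forward pass (see the later hunks) and threads the result through each block. A short usage sketch with illustrative shapes; it needs flash-attn installed so that bert_padding is importable.

# Sketch only: calls the helper exactly as defined in this diff; the batch
# shape and padding pattern are illustrative. Requires flash-attn (bert_padding).
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

bsz, S, past_key_len = 2, 4, 0
# Right-padded batch: the second sequence has one padding token.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]], dtype=torch.bool)

info = gen_flash_attn_padding_info(bsz, S, past_key_len,
                                   device=torch.device('cpu'),
                                   attention_mask_in_length=None,
                                   attention_mask=attention_mask)
print(sorted(info.keys()))
# ['cu_seqlens_k', 'cu_seqlens_q', 'indices_k', 'indices_q', 'indices_v',
#  'max_seqlen_k', 'max_seqlen_q']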
@@ -246,6 +294,14 @@ class MPTPreTrainedModel(PreTrainedModel):
     _no_split_modules = ['MPTBlock']


+def _fsdp_wrap_fn(
+    self: Union[MPTModel, MPTForCausalLM],
+    module: nn.Module,
+) -> bool:
+    # FSDP Wrap function for MPT Models
+    return isinstance(module, MPTBlock)
+
+
 class MPTModel(MPTPreTrainedModel):

     def __init__(self, config: MPTConfig):
@@ -515,10 +571,12 @@ def forward(
             raise ValueError(
                 'You cannot specify both input_ids and inputs_embeds.')
         elif input_ids is not None:
+            bsz = input_ids.size(0)
             S = input_ids.size(1)
             x = self.wte(input_ids)
             input_device = input_ids.device
         elif inputs_embeds is not None:
+            bsz = inputs_embeds.size(0)
             S = inputs_embeds.size(1)
             x = inputs_embeds
             input_device = inputs_embeds.device
@@ -530,22 +588,23 @@ def forward(
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

         rotary_emb_w_meta_info = None
-        if self.learned_pos_emb or self.rope:
-            past_position = 0
-            if past_key_values is not None:
-                if len(past_key_values) != self.config.n_layers:
-                    raise ValueError(
-                        f'past_key_values must provide a past_key_value for each attention '
-                        +
-                        f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
-                    )
-                # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
-                # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
-                # Here we shift position embedding using the `seq` dim of the past key
-                past_position = past_key_values[0][0].size(1)
-                if self.attn_impl == 'torch':
-                    past_position = past_key_values[0][0].size(3)

+        past_position = 0
+        if past_key_values is not None:
+            if len(past_key_values) != self.config.n_layers:
+                raise ValueError(
+                    f'past_key_values must provide a past_key_value for each attention '
+                    +
+                    f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
+                )
+            # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
+            # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+            # Here we shift position embedding using the `seq` dim of the past key
+            past_position = past_key_values[0][0].size(1)
+            if self.attn_impl == 'torch':
+                past_position = past_key_values[0][0].size(3)
+
+        if self.learned_pos_emb or self.rope:
             if self.learned_pos_emb and (S + past_position >
                                          self.config.max_seq_len):
                 raise ValueError(
@@ -623,6 +682,12 @@ def forward(

         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
+        flash_attn_padding_info = {}
+        if self.attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                bsz, S, past_position, x.device, attention_mask_in_length,
+                attention_mask)
+
         for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
@@ -637,8 +702,8 @@ def forward(
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
-                attention_mask_in_length=attention_mask_in_length,
                 alibi_slopes=alibi_slopes,
+                flash_attn_padding_info=flash_attn_padding_info,
             )
             if presents is not None:
                 presents += (present,)
@@ -673,7 +738,7 @@ def param_init_fn(self, module: nn.Module) -> None:

     # FSDP Wrap function
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        return _fsdp_wrap_fn(self, module)

     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
@@ -834,7 +899,7 @@ def param_init_fn(self, module: nn.Module) -> None:

     # FSDP Wrap function
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        return _fsdp_wrap_fn(self, module)

     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
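
The remaining hunks keep the external interface of MPTModel and MPTForCausalLM unchanged: the padding info is generated inside forward, once before the block loop, and the duplicated fsdp_wrap_fn bodies now delegate to the shared module-level _fsdp_wrap_fn. An end-to-end sketch of a forward call that exercises the new path follows; MPTConfig and its keyword arguments are not part of this diff and are assumptions, and running it requires flash-attn 2 plus a CUDA device.

# Sketch only: callers are unchanged by this refactor; the padding info is
# built inside MPTModel.forward. MPTConfig and its kwargs are assumptions here.
import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_seq_len=64,
    vocab_size=50368,
    attn_config={'attn_impl': 'flash'},  # routes through gen_flash_attn_padding_info
)
model = MPTForCausalLM(config).to(device='cuda', dtype=torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (2, 16), device='cuda')
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
attention_mask[1, -4:] = False  # right padding on the second sequence

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.logits.shape)  # expected: (2, 16, vocab_size)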

llmfoundry/utils/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
                                               log_config, pop_config,
                                               update_batch_size_info)
     from llmfoundry.utils.model_download_utils import (
-        download_from_cache_server, download_from_hf_hub)
+        download_from_hf_hub, download_from_http_fileserver)
 except ImportError as e:
     raise ImportError(
         'Please make sure to pip install . to get requirements for llm-foundry.'
@@ -28,7 +28,7 @@
     'build_tokenizer',
     'calculate_batch_size_info',
     'convert_and_save_ft_weights',
-    'download_from_cache_server',
+    'download_from_http_fileserver',
     'download_from_hf_hub',
     'get_hf_tokenizer_from_composer_state_dict',
     'update_batch_size_info',
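
The download helper is renamed from download_from_cache_server to download_from_http_fileserver in both the import and __all__. A minimal smoke test of the renamed re-export; the function is not called because its parameters are not shown in this diff.

# Sketch only: verifies the renamed symbol is re-exported at the package level.
from llmfoundry.utils import download_from_hf_hub, download_from_http_fileserver

print(callable(download_from_http_fileserver))  # True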
