 Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """

+from __future__ import annotations
+
 import math
 import warnings
 from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
@@ -24,15 +26,23 @@
 from composer.models import HuggingFaceModel
 from composer.utils import dist

-from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.attention import (is_flash_v1_installed,
+                                                is_flash_v2_installed)

 if is_flash_v2_installed():
     try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
         from flash_attn.layers.rotary import \
             RotaryEmbedding as DAILRotaryEmbedding
     except Exception as e:
         raise e

+if is_flash_v1_installed():
+    try:  # This try...except is needed because transformers requires it despite the 'if' statement above
+        from flash_attn import bert_padding
+    except Exception as e:
+        raise e
+
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -216,6 +226,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     return attention_mask_in_length


+def gen_flash_attn_padding_info(
+        bsz: int,
+        S: int,
+        past_key_len: int,
+        device: torch.device,
+        attention_mask_in_length: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None):
+    flash_attn_padding_info = {}
+    if attention_mask_in_length is None:
+        key_padding_mask = attention_mask
+        if key_padding_mask is None:
+            key_padding_mask = torch.ones((bsz, past_key_len + S),
+                                          dtype=torch.bool,
+                                          device=device)
+        query_padding_mask = key_padding_mask[:, -S:]
+        unpadding_function = bert_padding.unpad_input
+    else:
+        key_padding_mask = attention_mask_in_length
+        query_padding_mask = attention_mask_in_length
+        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+
+    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+        torch.empty(bsz, S, 1, device=device), query_padding_mask)
+    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
+        torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+    _, indices_v, _, _ = unpadding_function(
+        torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask)
+
+    flash_attn_padding_info['indices_q'] = indices_q
+    flash_attn_padding_info['indices_k'] = indices_k
+    flash_attn_padding_info['indices_v'] = indices_v
+    flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
+    flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
+    flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
+    flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
+    return flash_attn_padding_info
+
+
 def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
                       max_seq_len: int) -> torch.Tensor:
     seq_len = sequence_id.shape[-1]
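
For reference, the new gen_flash_attn_padding_info helper unpads the batch once per forward pass and returns the indices, cumulative sequence lengths, and maximum sequence lengths that FlashAttention's variable-length kernels expect, so downstream attention layers no longer need the raw attention_mask_in_length tensor. A minimal usage sketch follows; the batch size, sequence length, device, and expected values are illustrative assumptions rather than part of this diff, and flash-attn must be installed for bert_padding to import:

    import torch

    # With no attention_mask and no attention_mask_in_length, every position is
    # treated as valid, so for bsz=2, S=8, past_key_len=0 the cumulative query
    # lengths come out as [0, 8, 16] and max_seqlen_q == 8.
    info = gen_flash_attn_padding_info(bsz=2,
                                       S=8,
                                       past_key_len=0,
                                       device=torch.device('cuda'))
    print(info['cu_seqlens_q'], info['max_seqlen_q'])
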
@@ -246,6 +294,14 @@ class MPTPreTrainedModel(PreTrainedModel):
     _no_split_modules = ['MPTBlock']


+def _fsdp_wrap_fn(
+    self: Union[MPTModel, MPTForCausalLM],
+    module: nn.Module,
+) -> bool:
+    # FSDP Wrap function for MPT Models
+    return isinstance(module, MPTBlock)
+
+
 class MPTModel(MPTPreTrainedModel):

     def __init__(self, config: MPTConfig):
@@ -515,10 +571,12 @@ def forward(
             raise ValueError(
                 'You cannot specify both input_ids and inputs_embeds.')
         elif input_ids is not None:
+            bsz = input_ids.size(0)
             S = input_ids.size(1)
             x = self.wte(input_ids)
             input_device = input_ids.device
         elif inputs_embeds is not None:
+            bsz = inputs_embeds.size(0)
             S = inputs_embeds.size(1)
             x = inputs_embeds
             input_device = inputs_embeds.device
@@ -530,22 +588,23 @@ def forward(
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

         rotary_emb_w_meta_info = None
-        if self.learned_pos_emb or self.rope:
-            past_position = 0
-            if past_key_values is not None:
-                if len(past_key_values) != self.config.n_layers:
-                    raise ValueError(
-                        f'past_key_values must provide a past_key_value for each attention '
-                        +
-                        f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
-                    )
-                # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
-                # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
-                # Here we shift position embedding using the `seq` dim of the past key
-                past_position = past_key_values[0][0].size(1)
-                if self.attn_impl == 'torch':
-                    past_position = past_key_values[0][0].size(3)

+        past_position = 0
+        if past_key_values is not None:
+            if len(past_key_values) != self.config.n_layers:
+                raise ValueError(
+                    f'past_key_values must provide a past_key_value for each attention '
+                    +
+                    f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
+                )
+            # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
+            # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+            # Here we shift position embedding using the `seq` dim of the past key
+            past_position = past_key_values[0][0].size(1)
+            if self.attn_impl == 'torch':
+                past_position = past_key_values[0][0].size(3)
+
+        if self.learned_pos_emb or self.rope:
             if self.learned_pos_emb and (S + past_position >
                                          self.config.max_seq_len):
                 raise ValueError(
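
The hunk above hoists the past_position computation out of the learned_pos_emb/rope branch: the flash-attention padding info added below needs the cached key length regardless of which positional embedding is in use. The comments name two cache layouts, illustrated here with made-up shapes (not taken from the diff):

    import torch

    # flash/triton-style cached key: (batch, seq, dim) -> past_position from dim 1.
    key_flash = torch.zeros(2, 5, 512)
    assert key_flash.size(1) == 5
    # torch-impl cached key: (batch, heads, head_dim, seq) -> past_position from dim 3.
    key_torch = torch.zeros(2, 8, 64, 5)
    assert key_torch.size(3) == 5
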
@@ -623,6 +682,12 @@ def forward(

         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
+        flash_attn_padding_info = {}
+        if self.attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                bsz, S, past_position, x.device, attention_mask_in_length,
+                attention_mask)
+
         for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
@@ -637,8 +702,8 @@ def forward(
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
-                attention_mask_in_length=attention_mask_in_length,
                 alibi_slopes=alibi_slopes,
+                flash_attn_padding_info=flash_attn_padding_info,
             )
             if presents is not None:
                 presents += (present,)
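
The block call above now receives the precomputed flash_attn_padding_info dict in place of attention_mask_in_length. Inside an attention layer, that dict can be fed straight into FlashAttention's variable-length interface. The sketch below shows one plausible consumption path; it is not the actual llm-foundry attention code, the function name, shapes, and re-padding step are assumptions, and it presumes flash-attn v2:

    from flash_attn import bert_padding, flash_attn_varlen_func

    def attn_with_padding_info(q, k, v, padding_info, causal=True):
        # q: (batch, S, n_heads, head_dim); k, v: (batch, past_key_len + S, n_heads, head_dim).
        # Unpad with the precomputed indices, run the varlen kernel on the packed
        # tokens, then re-pad the output back to (batch, S, n_heads, head_dim).
        bsz, S = q.shape[0], q.shape[1]
        q_u = bert_padding.index_first_axis(q.flatten(0, 1), padding_info['indices_q'])
        k_u = bert_padding.index_first_axis(k.flatten(0, 1), padding_info['indices_k'])
        v_u = bert_padding.index_first_axis(v.flatten(0, 1), padding_info['indices_v'])
        out = flash_attn_varlen_func(q_u, k_u, v_u,
                                     cu_seqlens_q=padding_info['cu_seqlens_q'],
                                     cu_seqlens_k=padding_info['cu_seqlens_k'],
                                     max_seqlen_q=padding_info['max_seqlen_q'],
                                     max_seqlen_k=padding_info['max_seqlen_k'],
                                     causal=causal)
        return bert_padding.pad_input(out, padding_info['indices_q'], bsz, S)
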
@@ -673,7 +738,7 @@ def param_init_fn(self, module: nn.Module) -> None:

     # FSDP Wrap function
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        return _fsdp_wrap_fn(self, module)

     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool:
@@ -834,7 +899,7 @@ def param_init_fn(self, module: nn.Module) -> None:

     # FSDP Wrap function
     def fsdp_wrap_fn(self, module: nn.Module) -> bool:
-        return isinstance(module, MPTBlock)
+        return _fsdp_wrap_fn(self, module)

     # Activation Checkpointing
     def activation_checkpointing_fn(self, module: nn.Module) -> bool: