@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads

         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads // config.num_key_value_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads} != num_key_value_heads {config.num_key_value_heads}"
-            )

         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't spilt key value weight."
                 )
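
For reference, here is a minimal standalone sketch (hypothetical helper name, not part of the patch) of the head-splitting rule the hunk above enforces: KV heads are sharded across tensor-parallel ranks only when they divide evenly, the fused QKV path is rejected otherwise, and the KV weights are left unsplit.

```python
# Hypothetical sketch of the rule above, assuming the same config fields.
def plan_kv_split(num_key_value_heads: int, tensor_parallel_degree: int, fuse_attention_qkv: bool) -> int:
    """Return the per-rank number of KV heads, mirroring the branch in __init__."""
    if num_key_value_heads % tensor_parallel_degree == 0:
        # Even split: each tensor-parallel rank owns a slice of the KV heads.
        return num_key_value_heads // tensor_parallel_degree
    if fuse_attention_qkv:
        # Fused QKV cannot be sharded when KV heads don't divide evenly.
        raise ValueError("fuse_attention_qkv requires num_key_value_heads % tensor_parallel_degree == 0")
    # Otherwise the KV weights stay replicated (unsplit) on every rank.
    return num_key_value_heads


assert plan_kv_split(8, 4, fuse_attention_qkv=True) == 2   # 8 KV heads over 4 ranks
assert plan_kv_split(6, 4, fuse_attention_qkv=False) == 6  # uneven: keep KV heads whole
```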
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
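
Both projection branches (the ColumnParallelLinear above and the plain nn.Linear here) now size the fused output as hidden_size + 2 * num_key_value_heads * head_dim rather than 3 * hidden_size, which is what GQA/MQA requires. A minimal sketch of that arithmetic, assuming head_dim = hidden_size // num_attention_heads as in __init__:

```python
# Sketch of the fused QKV output width used above (function name is illustrative).
def fused_qkv_out_features(hidden_size: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    head_dim = hidden_size // num_attention_heads
    # Queries keep all attention heads; keys and values only need the KV heads.
    return hidden_size + 2 * num_key_value_heads * head_dim


# Plain MHA (KV heads == attention heads) reduces to the old 3 * hidden_size.
assert fused_qkv_out_features(4096, 32, 32) == 3 * 4096
# GQA: 32 query heads sharing 8 KV heads gives a narrower fused projection.
assert fused_qkv_out_features(4096, 32, 8) == 4096 + 2 * 8 * 128
```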
@@ -776,7 +779,11 @@ def forward(
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +792,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
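
The fused path now groups the projection output per KV head as (num_key_value_groups + 2) * head_dim channels, splits it into query/key/value slices of num_key_value_groups * head_dim, head_dim, and head_dim, and reshapes the query slice back to num_heads when GQA/MQA is active. A minimal NumPy sketch of that split under assumed toy shapes (not the Paddle code itself):

```python
import numpy as np

# Toy shapes, chosen only for illustration.
bs, seq_len, num_heads, num_key_value_heads, head_dim = 2, 4, 8, 2, 16
num_key_value_groups = num_heads // num_key_value_heads  # query heads per KV head

# Fused output reshaped per KV head, as in the target_shape branches above.
mix_layer = np.zeros((bs, seq_len, num_key_value_heads, (num_key_value_groups + 2) * head_dim))

# np.split takes cut points; these match paddle.split's num_or_sections of
# [num_key_value_groups * head_dim, head_dim, head_dim] on the last axis.
cuts = [num_key_value_groups * head_dim, (num_key_value_groups + 1) * head_dim]
query_states, key_states, value_states = np.split(mix_layer, cuts, axis=-1)

# With GQA/MQA the query slice is folded back to one slot per query head.
query_states = query_states.reshape(bs, seq_len, num_heads, head_dim)
assert query_states.shape == (2, 4, 8, 16)
assert key_states.shape == value_states.shape == (2, 4, 2, 16)
```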
@@ -807,11 +825,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,