@@ -588,17 +588,15 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         self.head_dim = self.hidden_size // config.num_attention_heads

         self.num_key_value_heads = config.num_key_value_heads
+        assert config.num_attention_heads % config.num_key_value_heads == 0
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads

         self.max_position_embeddings = config.max_position_embeddings
         self.seq_length = config.seq_length
         self.sequence_parallel = config.sequence_parallel

         self.fuse_attention_qkv = config.fuse_attention_qkv
-        if self.fuse_attention_qkv and config.num_attention_heads != config.num_key_value_heads:
-            raise ValueError(
-                f"fuse_attention_qkv can't be True when num_attention_heads {config.num_attention_heads} != num_key_value_heads {config.num_key_value_heads}"
-            )

         self.kv_indices = None
         # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True
@@ -615,6 +613,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.num_key_value_heads % config.tensor_parallel_degree == 0:
                 self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree
             else:
+                if self.fuse_attention_qkv:
+                    # TODO(Yuang): support fusion for kv when kv heads cannot be divided by mp
+                    raise ValueError(
+                        f"fuse_attention_qkv can't be True when num_key_value_heads {config.num_key_value_heads} % tensor_parallel_degree {config.tensor_parallel_degree} != 0"
+                    )
                 logger.warning(
                     f"Get num_key_value_heads: {self.num_key_value_heads}, can't split to tensor_parallel_degree: {config.tensor_parallel_degree}, so we don't split the key value weight."
                 )
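With this change, fusion is rejected only when the kv heads cannot be evenly split across tensor-parallel ranks, rather than whenever GQA/MQA is configured. A minimal sketch of the relocated guard, using made-up numbers (not taken from the diff):

```python
# Illustrative values only; mirrors the condition of the new ValueError above.
num_key_value_heads, tensor_parallel_degree = 4, 8
fuse_attention_qkv = True

if fuse_attention_qkv and num_key_value_heads % tensor_parallel_degree != 0:
    raise ValueError(
        f"fuse_attention_qkv can't be True when num_key_value_heads {num_key_value_heads} "
        f"% tensor_parallel_degree {tensor_parallel_degree} != 0"
    )
# With num_key_value_heads = 8 and tensor_parallel_degree = 4 the guard passes,
# so fusion is now allowed even though num_attention_heads != num_key_value_heads (GQA).
```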
@@ -644,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     has_bias=False,
                     gather_output=False,
                 )
@@ -684,7 +687,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             if self.fuse_attention_qkv:
                 self.qkv_proj = nn.Linear(
                     self.hidden_size,
-                    3 * self.hidden_size,
+                    self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
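In both hunks above, the fused output width changes from `3 * self.hidden_size` to `hidden_size + 2 * num_key_value_heads * head_dim`, which reduces to the old value under MHA. A quick sketch of the arithmetic, using illustrative config values (not taken from the diff):

```python
# Illustrative values only: 32 query heads sharing 4 kv heads (GQA).
hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 4
head_dim = hidden_size // num_attention_heads  # 128

fused_width = hidden_size + 2 * num_key_value_heads * head_dim
print(fused_width)  # 5120: a full hidden_size of q plus two smaller k/v blocks

# Under MHA (num_key_value_heads == num_attention_heads) the same formula
# gives back the original 3 * hidden_size.
assert hidden_size + 2 * num_attention_heads * head_dim == 3 * hidden_size
```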
@@ -771,12 +774,27 @@ def forward(

         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight of qkv_proj has a shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has a shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer on the last axis into three sections of sizes [num_groups * head_dim, head_dim, head_dim],
+            # which represent q, k and v respectively.
+            # q then has a shape like [b, s, num_kv_heads, num_groups * head_dim],
+            # while k and v have shapes like [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q must be reshaped into [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
                     mix_layer = paddle.reshape_(
                         mix_layer,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, 3 * self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_heads * self.head_dim + 2 * self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
                 mix_layer = self.reshard_layer(
@@ -785,15 +803,26 @@ def forward(
                     concat_axis=1,
                 )
                 mix_layer = paddle.reshape_(
-                    mix_layer, [0, self.seq_length, -1, 3 * self.head_dim]
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
                 )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
             else:
                 if self.sequence_parallel:
-                    target_shape = [-1, self.seq_length, self.num_heads, 3 * self.head_dim]
+                    target_shape = [
+                        -1,
+                        self.seq_length,
+                        self.num_key_value_heads,
+                        (self.num_key_value_groups + 2) * self.head_dim,
+                    ]
                 else:
-                    target_shape = [0, 0, self.num_heads, 3 * self.head_dim]
+                    target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
                 mix_layer = paddle.reshape_(mix_layer, target_shape)
-            query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1)
+            query_states, key_states, value_states = paddle.split(
+                mix_layer,
+                num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
+                axis=-1,
+            )
+            if self.gqa_or_mqa:
+                query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim])
         else:
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
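To make the reshape/split arithmetic from the NOTE above concrete, here is a standalone sketch that walks a fused output through the same shapes. The batch, sequence, and head sizes are illustrative and not taken from the diff:

```python
import paddle

# Illustrative sizes: 32 query heads sharing 4 kv heads (GQA), head_dim = 64.
b, s, head_dim = 2, 8, 64
num_q_heads, num_kv_heads = 32, 4
num_groups = num_q_heads // num_kv_heads          # 8 query heads per kv head
hidden_size = num_q_heads * head_dim

# Fused projection output: [b, s, hidden_size + 2 * num_kv_heads * head_dim]
mix_layer = paddle.randn([b, s, hidden_size + 2 * num_kv_heads * head_dim])

# Reshape to [b, s, num_kv_heads, (num_groups + 2) * head_dim]
mix_layer = paddle.reshape(mix_layer, [b, s, num_kv_heads, (num_groups + 2) * head_dim])

# Split the last axis into q / k / v sections.
q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[num_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
print(q.shape)  # [2, 8, 4, 512]
print(k.shape)  # [2, 8, 4, 64]
print(v.shape)  # [2, 8, 4, 64]

# For GQA/MQA, give every query head its own axis entry.
q = paddle.reshape(q, [b, s, num_q_heads, head_dim])
print(q.shape)  # [2, 8, 32, 64]
```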
@@ -807,11 +836,19 @@ def forward(
                     )
                     key_states = paddle.reshape(
                         key_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                     value_states = paddle.reshape(
                         value_states,
-                        [-1, self.seq_length // self.config.sep_parallel_degree, self.num_heads * self.head_dim],
+                        [
+                            -1,
+                            self.seq_length // self.config.sep_parallel_degree,
+                            self.num_key_value_heads * self.head_dim,
+                        ],
                     )
                 query_states = self.reshard_layer(
                     query_states,
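The last hunk applies the same correction to the non-fused path: under GQA, k_proj and v_proj emit `num_key_value_heads * head_dim` features per token rather than `num_heads * head_dim`, so the sequence-parallel reshape has to use the kv width. A hedged sketch with the same illustrative sizes as above:

```python
import paddle

# Illustrative sizes only (not from the diff): 32 q heads, 4 kv heads, head_dim 64.
b, s, sep = 2, 16, 2
num_q_heads, num_kv_heads, head_dim = 32, 4, 64

# Under GQA, the k (and v) projection produces num_kv_heads * head_dim features per token.
key_states = paddle.randn([b * s, num_kv_heads * head_dim])

# The sequence-parallel reshape must therefore use the kv width, not the q width.
key_states = paddle.reshape(key_states, [-1, s // sep, num_kv_heads * head_dim])
print(key_states.shape)  # [4, 8, 256]

# Using the old width (num_q_heads * head_dim = 2048) would fail here, since the
# tensor's b * s * 256 elements cannot fill rows of 8 * 2048 elements evenly.
```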