@@ -774,6 +774,17 @@ def forward(
 
         if self.fuse_attention_qkv:
            mix_layer = self.qkv_proj(hidden_states)
+            # NOTE on GQA attention fusion (compatible with MHA and MQA):
+            # The qkv_proj weight has shape [hidden_size, hidden_size + 2 * num_kv_heads * head_dim],
+            # so after the projection, mix_layer has shape [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer to [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Then split the last axis into three sections of sizes [num_groups * head_dim, head_dim, head_dim],
+            # which are q, k and v respectively.
+            # q has shape [b, s, num_kv_heads, num_groups * head_dim];
+            # k and v have shape [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q must be reshaped to [b, s, num_q_heads, head_dim].
            if self.reshard_layer is not None:
                if self.sequence_parallel:
                    assert self.seq_length % self.config.sep_parallel_degree == 0
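
As a sanity check on the shape arithmetic described in the NOTE above, here is a minimal NumPy sketch of the reshape-and-split step. The sizes and variable names are made up for illustration; the real model performs the equivalent operations with the framework's tensor ops on the output of `qkv_proj`.

```python
import numpy as np

b, s = 2, 8                       # batch size, sequence length (illustrative)
num_q_heads, num_kv_heads = 8, 2  # GQA: 8 query heads share 2 kv heads
head_dim = 16
num_groups = num_q_heads // num_kv_heads
hidden_size = num_q_heads * head_dim

# Output of the fused projection: [b, s, hidden_size + 2 * num_kv_heads * head_dim]
mix_layer = np.random.randn(b, s, hidden_size + 2 * num_kv_heads * head_dim)

# Reshape to [b, s, num_kv_heads, (num_groups + 2) * head_dim]
mix_layer = mix_layer.reshape(b, s, num_kv_heads, (num_groups + 2) * head_dim)

# Split the last axis into sections of sizes [num_groups * head_dim, head_dim, head_dim]
q, k, v = np.split(
    mix_layer,
    [num_groups * head_dim, (num_groups + 1) * head_dim],
    axis=-1,
)
assert q.shape == (b, s, num_kv_heads, num_groups * head_dim)
assert k.shape == (b, s, num_kv_heads, head_dim)
assert v.shape == (b, s, num_kv_heads, head_dim)

# Under MHA (num_groups == 1) q already has one head_dim slice per query head;
# for GQA/MQA, regroup q into [b, s, num_q_heads, head_dim].
q = q.reshape(b, s, num_q_heads, head_dim)
```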