Commit c79b7e2

add annotation for the fusion
1 parent 896c01f commit c79b7e2

File tree: 1 file changed (+11 −0 lines changed)

paddlenlp/transformers/llama/modeling.py

Lines changed: 11 additions & 0 deletions
@@ -774,6 +774,17 @@ def forward(

         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight of qkv_proj has a shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has a shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer into [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer along the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to obtain q, k and v respectively.
+            # q then has a shape like [b, s, num_kv_heads, num_groups * head_dim];
+            # k and v have a shape like [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following computation since num_kv_heads == num_q_heads,
+            # but under GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
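For reference, the reshape-and-split sequence described in the new comment can be sketched in a few lines of standalone PaddlePaddle. This is an illustrative sketch only, not the exact code in modeling.py; the sizes and variable names (b, s, num_q_heads, num_kv_heads, head_dim, query_states, ...) are placeholders chosen to match the comment.

# Minimal sketch (assumed example) of the GQA-compatible fused-QKV split described above.
import paddle

b, s = 2, 8                                    # placeholder batch size and sequence length
num_q_heads, num_kv_heads, head_dim = 8, 2, 16
num_groups = num_q_heads // num_kv_heads       # query heads per kv head
hidden_size = num_q_heads * head_dim

# Output of the fused projection: [b, s, hidden_size + 2 * num_kv_heads * head_dim]
mix_layer = paddle.randn([b, s, hidden_size + 2 * num_kv_heads * head_dim])

# Group the fused dimension by kv head: [b, s, num_kv_heads, (num_groups + 2) * head_dim]
mix_layer = mix_layer.reshape([b, s, num_kv_heads, (num_groups + 2) * head_dim])

# Split the last axis into q, k and v sections
query_states, key_states, value_states = paddle.split(
    mix_layer,
    num_or_sections=[num_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# query_states: [b, s, num_kv_heads, num_groups * head_dim]
# key_states, value_states: [b, s, num_kv_heads, head_dim]

# For GQA/MQA, give each query head its own head_dim slice (a no-op reshape under MHA,
# where num_groups == 1 and num_kv_heads == num_q_heads)
query_states = query_states.reshape([b, s, num_q_heads, head_dim])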
