Conversation


@deepllz deepllz commented Jun 3, 2024

PR types

Performance optimization

PR changes

Models

Description

Add `fuse_attention_ffn` support for the Qwen model.


paddle-bot bot commented Jun 3, 2024

Thanks for your contribution!

@deepllz deepllz force-pushed the qwen branch 3 times, most recently from 5139462 to 18b5946, on June 3, 2024 at 10:04.

codecov bot commented Jun 3, 2024

Codecov Report

Attention: Patch coverage is 50.00000% with 14 lines in your changes missing coverage. Please review.

Project coverage is 53.86%. Comparing base (87edf28) to head (95c1780).
Report is 242 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| paddlenlp/transformers/qwen/modeling.py | 46.15% | 14 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8526      +/-   ##
===========================================
- Coverage    53.87%   53.86%   -0.01%     
===========================================
  Files          620      620              
  Lines        97081    97101      +20     
===========================================
+ Hits         52299    52308       +9     
- Misses       44782    44793      +11     


```python
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=not config.no_bias)
if self.fuse_attention_ffn:
    self.gate_up_fused_proj = nn.Linear(config.hidden_size, ff_dim_in * 2, bias_attr=not config.no_bias)
```
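A minimal NumPy sketch (not the Paddle implementation) of why this fusion is valid: concatenating the two projection weights along the output dimension lets a single GEMM replace two, and splitting the result afterwards recovers the two activations exactly. Shapes and names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, ffn = 8, 16
x = rng.standard_normal((2, hidden))
w1 = rng.standard_normal((hidden, ffn))  # stands in for w1's weight
w2 = rng.standard_normal((hidden, ffn))  # stands in for w2's weight

# Unfused path: two separate GEMMs.
a, b = x @ w1, x @ w2

# Fused path: one GEMM over the concatenated weight, then split.
w_fused = np.concatenate([w1, w2], axis=1)   # (hidden, 2 * ffn)
a_f, b_f = np.split(x @ w_fused, 2, axis=1)

assert np.allclose(a, a_f) and np.allclose(b, b_f)
```

The fused variant trades two kernel launches for one larger GEMM, which is the source of the speedup on accelerators.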
Contributor:
```python
def get_tensor_parallel_split_mappings(num_hidden_layers):
    final_actions = {}
    base_actions = {
        # Column Linear
        "lm_head.weight": partial(fn, is_column=True),
        "qwen.h.0.mlp.w2.weight": partial(fn, is_column=True),
        "qwen.h.0.mlp.w1.weight": partial(fn, is_column=True),
        "qwen.h.0.attn.c_attn.weight": partial(fn, is_column=True, is_naive_3fuse=True),
        "qwen.h.0.attn.c_attn.bias": partial(fn, is_column=True, is_naive_3fuse=True),
        # Row Linear
        "qwen.wte.weight": partial(fn, is_column=False),
        "qwen.h.0.mlp.c_proj.weight": partial(fn, is_column=False),
        "qwen.h.0.attn.c_proj.weight": partial(fn, is_column=False),
    }
    for key, action in base_actions.items():
        if "h.0." in key:
            for i in range(num_hidden_layers):
                final_actions[key.replace("h.0.", f"h.{i}.")] = action
        final_actions[key] = action
    return final_actions

mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
```

You need to adapt the split rules here: the tensor parallel split rule for `gate_up_fused_proj`.
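The reviewer's point can be sketched in NumPy under assumed shapes: a fused gate/up weight cannot be split contiguously across tensor-parallel ranks, since a contiguous split would hand rank 0 all of `w1` and rank 1 all of `w2`. Instead, each constituent sub-weight must be sharded separately and the rank's shards re-concatenated, the same idea as the `is_naive_3fuse` action above but for a 2-way fuse. `naive_fuse_column_split` is a hypothetical helper for illustration, not the PaddleNLP API.

```python
import numpy as np

def naive_fuse_column_split(w_fused, num_fused, tp_degree, rank):
    # Shard each constituent sub-weight along its columns separately,
    # then concatenate this rank's shards back together.
    parts = np.split(w_fused, num_fused, axis=-1)
    shards = [np.split(p, tp_degree, axis=-1)[rank] for p in parts]
    return np.concatenate(shards, axis=-1)

hidden, ffn, tp = 4, 6, 2
w1 = np.arange(hidden * ffn).reshape(hidden, ffn)    # gate sub-weight
w2 = -np.arange(hidden * ffn).reshape(hidden, ffn)   # up sub-weight
w_fused = np.concatenate([w1, w2], axis=1)           # (hidden, 2 * ffn)

# Wrong: a contiguous split gives rank 0 all of w1 and none of w2.
assert np.array_equal(np.split(w_fused, tp, axis=-1)[0], w1)

# Right: with the fuse-aware split, each rank holds half of w1 AND half of w2.
rank0 = naive_fuse_column_split(w_fused, num_fused=2, tp_degree=tp, rank=0)
assert np.array_equal(rank0[:, : ffn // 2], w1[:, : ffn // 2])
assert np.array_equal(rank0[:, ffn // 2 :], w2[:, : ffn // 2])
```

Without such a rule, loading a fused checkpoint under tensor parallelism would silently mis-assign the gate and up halves.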

Collaborator (author):
OK. There is a follow-up PR adding sequence parallel (SP) support; I'll add this adaptation together in that next PR.

@wawltor wawltor merged commit d06b327 into PaddlePaddle:develop Jun 6, 2024