Skip to content

Conversation

@wwbitejotunn
Copy link
Contributor

@wwbitejotunn wwbitejotunn commented Dec 29, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Add nvidia tensorrt oss plugin flash attention and cross attention support to accelerate the inference speed of stable diffusion and other models.
Using flash attention and cross attention plugin, the stable diffusion latency can be speed up from 1.52s to 1.02s.
Tensorrt 8.5.2 is required to using those plugins.
Using nsys, we can see that the plugin are successful involved by unit test under trt8.5.2.2 environment

image

image

@paddle-bot
Copy link

paddle-bot bot commented Dec 29, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@wwbitejotunn wwbitejotunn marked this pull request as ready for review December 29, 2022 09:09

void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
#ifdef PADDLE_WITH_TENSORRT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个if 宏是不是应该包到562行后

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我理解这里进行early stop的话是要在build fusion之前进行, 因此放在ApplyImpl的最开头.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

经过讨论后, 目前采用运行期判定的方式进行early stop, 麻烦辛苦再看一下

FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();

#ifdef PADDLE_WITH_TENSORRT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

经过讨论后, 目前采用运行期判定的方式进行early stop, 麻烦辛苦再看一下

MARD1NO
MARD1NO previously approved these changes Jan 6, 2023
std::get<2>(trt_version) * 10 <
8520) {
VLOG(3) << "Flash attention oss plugin only available for trt version >= "
"8.5.2.2. Stop this pass";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只是输出了日志?应该return吧?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看着pass有限定trt 8.5.2.2才注册,这里应该不用判断了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

考虑后还是在这里进行runtime时的early stop, 因此现在在这里加上了return

@wwbitejotunn wwbitejotunn reopened this Jan 9, 2023
@wwbitejotunn wwbitejotunn marked this pull request as draft January 10, 2023 03:22
@wwbitejotunn wwbitejotunn marked this pull request as ready for review January 10, 2023 03:23
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@heavengate heavengate merged commit a48b8e2 into PaddlePaddle:develop Jan 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants