- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.9k
add oss flash fmha and fmhca support #49438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add oss flash fmha and fmhca support #49438
Conversation
| 你的PR提交成功,感谢你对开源项目的贡献! | 
4cecf1d    to
    9c1e126      
    Compare
  
    |  | ||
| void TrtCrossMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { | ||
| FusePassBase::Init(name_scope_, graph); | ||
| #ifdef PADDLE_WITH_TENSORRT | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个if 宏是不是应该包到562行后
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我理解这里进行early stop的话是要在build fusion之前进行, 因此放在ApplyImpl的最开头.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
经过讨论后, 目前采用运行期判定的方式进行early stop, 麻烦辛苦再看一下
        
          
                paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
              
                Outdated
          
            Show resolved
            Hide resolved
        
      | FusePassBase::Init(name_scope_, graph); | ||
| auto* scope = param_scope(); | ||
|  | ||
| #ifdef PADDLE_WITH_TENSORRT | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
经过讨论后, 目前采用运行期判定的方式进行early stop, 麻烦辛苦再看一下
        
          
                paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc
              
                Outdated
          
            Show resolved
            Hide resolved
        
      | std::get<2>(trt_version) * 10 < | ||
| 8520) { | ||
| VLOG(3) << "Flash attention oss plugin only available for trt version >= " | ||
| "8.5.2.2. Stop this pass"; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里只是输出了日志?应该return吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看着pass有限定trt 8.5.2.2才注册,这里应该不用判断了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
考虑后还是在这里进行runtime时的early stop, 因此现在在这里加上了return
        
          
                paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
                paddle/fluid/framework/ir/trt_flash_multihead_matmul_fuse_pass.cc
              
                Outdated
          
            Show resolved
            Hide resolved
        
      refine compile fix compile
9691d16    to
    6a679e8      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
Add nvidia tensorrt oss plugin flash attention and cross attention support to accelerate the inference speed of stable diffusion and other models.
Using flash attention and cross attention plugin, the stable diffusion latency can be speed up from 1.52s to 1.02s.
Tensorrt 8.5.2 is required to using those plugins.
Using nsys, we can see that the plugin are successful involved by unit test under trt8.5.2.2 environment