
Commit 36edb0e

[cherry-pick] Add sparse attention cherrypick (PaddlePaddle#36447)
The code of this PR only supports CUDA 11.2. CI currently has no GPU with CUDA 11.2, so all tests will be skipped automatically. The new OP is paddle._C_ops.sparse_attention. The Python API will be added in a follow-up PR. This PR also lacks dynamic-graph and static-graph tests, which will be added in subsequent PRs.
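For reference, a minimal usage sketch of invoking the new op directly in dynamic-graph mode. This is not part of the commit: the op name paddle._C_ops.sparse_attention comes from the message above, while the positional argument order (Q, K, V, Offset, Columns) and the three returned tensors (Out, SparseDotSdd, Softmax) are assumptions based on the operator definition in this diff; a GPU build with CUDA 11.2 is required.

# Hypothetical usage sketch (not from this PR). Argument and return order are
# assumed from the operator's declared inputs/outputs; requires CUDA 11.2.
import numpy as np
import paddle

batch_size, num_heads, target_len, head_dim = 1, 1, 16, 16
q = paddle.rand([batch_size, num_heads, target_len, head_dim], dtype='float32')
k = paddle.rand([batch_size, num_heads, target_len, head_dim], dtype='float32')
v = paddle.rand([batch_size, num_heads, target_len, head_dim], dtype='float32')

# CSR description of a fully dense target_len x target_len attention pattern:
# Offset has target_len + 1 entries per (batch, head); Columns concatenates the
# attended key positions of every query row.
offset_np = np.tile(np.arange(0, target_len + 1, dtype='int32') * target_len,
                    (batch_size, num_heads, 1))
columns_np = np.tile(np.arange(target_len, dtype='int32'),
                     (batch_size, num_heads, target_len))
offset = paddle.to_tensor(offset_np)
columns = paddle.to_tensor(columns_np)

out, sparse_dot_sdd, softmax = paddle._C_ops.sparse_attention(
    q, k, v, offset, columns)
print(out.shape)  # [batch_size, num_heads, target_len, head_dim]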

File tree

8 files changed (+960, -8 lines)

cmake/operators.cmake

Lines changed: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ function(op_library TARGET)
  foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op"
    "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
    "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
-   "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
+   "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
    "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
    "fused_bn_add_activation_op")
  if ("${TARGET}" STREQUAL "${manual_pybind_op}")

paddle/fluid/operators/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -78,7 +78,7 @@ if(WITH_UNITY_BUILD)
    include(unity_build_rule.cmake)
endif()

-register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
+register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op sparse_attention_op lstm_op run_program_op eye_op recurrent_op
    sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
@@ -94,6 +94,10 @@ if (WITH_GPU OR WITH_ROCM)
    endif()
    op_library(sync_batch_norm_op)
    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
+   if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) )
+     op_library(sparse_attention_op)
+     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n")
+   endif()
else()
    op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
paddle/fluid/operators/sparse_attention_op.cc

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Q",
        "(Tensor), The input tensor of query in attention, "
        "whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
    AddInput(
        "K",
        "(Tensor), The input tensor of key in attention, "
        "whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
    AddInput(
        "V",
        "(Tensor), The input tensor of value in attention, "
        "whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
    AddInput("Offset",
             "(Tensor, default: Tensor<int32>), The input tensor of offset in "
             "CSR sparse format, "
             "whose dimension : `[batch_size, num_heads, target_len + 1]`.");
    AddInput("Columns",
             "(Tensor, default: Tensor<int32>), The input tensor of columns in "
             "CSR sparse format, "
             "whose dimension : `[batch_size, num_heads, sparse_nnz_num]`.");
    AddOutput(
        "Out",
        "(Tensor), The output tensor of result in attention, "
        "whose dimension : `[batch_size, num_heads, target_len, head_dim]`.");
    AddOutput("SparseDotSdd",
              "(Tensor), The output tensor of result in SparseDotSdd step, "
              "whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.")
        .AsIntermediate();
    AddOutput("Softmax",
              "(Tensor), The output tensor of result in Softmax step, "
              "whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`.")
        .AsIntermediate();
    AddComment(R"DOC(
      Compute the value of the sparse attention module. Its input value includes five tensors.
      Q, K, and V represent query, key, and value in the Attention module, respectively.
      The CSR format is used to represent the sparsity feature in the Attention module.
      The CSR format contains two tensors, offset and columns.
      )DOC");
  }
};

class SparseAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention");
    OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention");
    OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention");
    OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
                   "sparse_attention");
    OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns",
                   "sparse_attention");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sparse_attention");
    OP_INOUT_CHECK(ctx->HasOutput("SparseDotSdd"), "Output", "SparseDotSdd",
                   "sparse_attention");
    OP_INOUT_CHECK(ctx->HasOutput("Softmax"), "Output", "Softmax",
                   "sparse_attention");

    auto dims_q = ctx->GetInputDim("Q");
    auto dims_k = ctx->GetInputDim("K");
    auto dims_v = ctx->GetInputDim("V");
    auto dims_columns = ctx->GetInputDim("Columns");

    PADDLE_ENFORCE_EQ(dims_q.size(), static_cast<size_t>(4),
                      platform::errors::InvalidArgument(
                          "Dimension in query' shapes should be 4."));
    PADDLE_ENFORCE_EQ(dims_k.size(), static_cast<size_t>(4),
                      platform::errors::InvalidArgument(
                          "Dimension in key' shapes should be 4."));
    PADDLE_ENFORCE_EQ(dims_v.size(), static_cast<size_t>(4),
                      platform::errors::InvalidArgument(
                          "Dimension in value' shapes should be 4."));

    auto batch_size = dims_q[0];
    auto num_heads = dims_q[1];
    auto M = dims_q[2];
    auto N = dims_q[3];
    auto sparse_nnz = dims_columns[2];
    ctx->SetOutputDim("Out", {batch_size, num_heads, M, N});
    ctx->SetOutputDim("SparseDotSdd", {batch_size, num_heads, sparse_nnz});
    ctx->SetOutputDim("Softmax", {batch_size, num_heads, sparse_nnz});
    ctx->ShareLoD("Q", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "Q", "K");
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class SparseAttentionOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("K"), "Input", "K", "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
                   "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("Columns"), "Input", "Columns",
                   "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("SparseDotSdd"), "Input", "SparseDotSdd",
                   "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput("Softmax"), "Input", "Softmax",
                   "sparse_attention_grad");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@GRAD", "sparse_attention_grad");

    auto x_grad_name = framework::GradVarName("Q");
    auto y_grad_name = framework::GradVarName("K");
    auto z_grad_name = framework::GradVarName("V");

    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("Q"));
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("K"));
    }
    if (ctx->HasOutput(z_grad_name)) {
      ctx->SetOutputDim(z_grad_name, ctx->GetInputDim("V"));
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.GetPlace());
  }
};

template <typename T>
class SparseAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("sparse_attention_grad");
    op->SetInput("Q", this->Input("Q"));
    op->SetInput("K", this->Input("K"));
    op->SetInput("V", this->Input("V"));
    op->SetInput("Offset", this->Input("Offset"));
    op->SetInput("Columns", this->Input("Columns"));
    op->SetInput("SparseDotSdd", this->Output("SparseDotSdd"));
    op->SetInput("Softmax", this->Output("Softmax"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("Q"), this->InputGrad("Q"));
    op->SetOutput(framework::GradVarName("K"), this->InputGrad("K"));
    op->SetOutput(framework::GradVarName("V"), this->InputGrad("V"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_attention, ops::SparseAttentionOp,
                  ops::SparseAttentionOpMaker,
                  ops::SparseAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::SparseAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(sparse_attention_grad, ops::SparseAttentionOpGrad);
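The DOC block above describes the sparsity pattern with two CSR tensors, Offset and Columns. As a hedged illustration (not part of this commit; mask_to_csr is a hypothetical helper), this is how a dense 0/1 attention mask for a single (batch, head) slice maps onto that pair:

# Illustrative helper (not from this PR): convert a dense 0/1 attention mask of
# shape [target_len, target_len] into the per-row CSR Offset and Columns arrays
# that the operator expects for one (batch, head) slice.
import numpy as np

def mask_to_csr(mask):
    """Return (offset, columns) for a 2-D 0/1 mask; offset has target_len + 1
    entries, columns lists the attended key positions row by row."""
    target_len = mask.shape[0]
    offset = np.zeros(target_len + 1, dtype='int32')
    columns = []
    for row in range(target_len):
        cols = np.nonzero(mask[row])[0]
        columns.extend(cols.tolist())
        offset[row + 1] = offset[row] + len(cols)
    return offset, np.array(columns, dtype='int32')

# Causal (lower-triangular) mask for target_len = 4:
mask = np.tril(np.ones((4, 4), dtype='int32'))
offset, columns = mask_to_csr(mask)
# offset  -> [0 1 3 6 10]
# columns -> [0 0 1 0 1 2 0 1 2 3]

These arrays would then be tiled to `[batch_size, num_heads, target_len + 1]` and `[batch_size, num_heads, sparse_nnz_num]` before being fed to the operator as Offset and Columns.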

0 commit comments