
Commit 10225d2

Author: zhangkaihuo

[cherry-pick] Sparse static graph (#46838)
Cherry-pick of #46322 and #46245: make the Sparse API support static graph mode.

1 parent 976af0d · commit 10225d2


45 files changed, +937 / -74 lines (only a subset of the changed files is shown below)

paddle/fluid/framework/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -190,7 +190,7 @@ cc_test(
 cc_library(
   var_type_traits
   SRCS var_type_traits.cc
-  DEPS framework_proto scope tensor_array)
+  DEPS framework_proto scope tensor_array sparse_coo_tensor)
 if(WITH_GPU)
   target_link_libraries(var_type_traits dynload_cuda)
 endif()

@@ -1138,7 +1138,8 @@ cc_library(
   phi
   phi_api_utils
   op_info
-  shape_inference)
+  shape_inference
+  sparse_coo_tensor)
 cc_test(
   infershape_utils_test
   SRCS infershape_utils_test.cc

paddle/fluid/framework/feed_fetch_type.h

Lines changed: 10 additions & 2 deletions
@@ -22,10 +22,11 @@ limitations under the License. */

 namespace paddle {
 namespace framework {
-using FeedType = paddle::variant<LoDTensor, Strings>;
+using FeedType = paddle::variant<LoDTensor, Strings, phi::SparseCooTensor>;
 using FeedList = std::vector<FeedType>;

-using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
+using FetchType = paddle::
+    variant<LoDTensor, LoDTensorArray, framework::Vocab, phi::SparseCooTensor>;
 using FetchList = std::vector<FetchType>;

 using FetchUnmergedList = std::vector<std::vector<FetchType>>;

@@ -52,6 +53,13 @@ inline bool data_is_string_tensor(const FeedType &data) {
   return false;
 }

+inline bool data_is_sparse_coo_tensor(const FetchType &data) {
+  if (data.type() == typeid(phi::SparseCooTensor)) {
+    return true;
+  }
+  return false;
+}
+
 static const char kFeedOpType[] = "feed";
 static const char kFetchOpType[] = "fetch";
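
The widened FeedType/FetchType variants mean fetch consumers may now receive a phi::SparseCooTensor. Below is a minimal sketch, not part of this commit, of how a fetched variant could be inspected using the data_is_sparse_coo_tensor helper above; the function name InspectFetch is hypothetical, and the sketch assumes the PADDLE_GET_CONST macro used elsewhere in this diff is available.

// Sketch only: unwrap a FetchType that may hold a phi::SparseCooTensor.
// PADDLE_GET_CONST throws if the variant holds a different alternative.
#include "glog/logging.h"
#include "paddle/fluid/framework/feed_fetch_type.h"

void InspectFetch(const paddle::framework::FetchType& fetched) {
  if (paddle::framework::data_is_sparse_coo_tensor(fetched)) {
    const auto& coo = PADDLE_GET_CONST(phi::SparseCooTensor, fetched);
    VLOG(3) << "fetched SparseCooTensor with " << coo.nnz() << " non-zeros";
  }
}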

paddle/fluid/framework/framework.proto

Lines changed: 3 additions & 0 deletions
@@ -154,6 +154,8 @@ message VarType {
     FEED_LIST = 28;
     // The data type of phi::StringTensor
     PSTRING = 29;
+    // the data type of phi::SparseCooTensor
+    SPARSE_COO = 30;
   }

   required Type type = 1;

@@ -186,6 +188,7 @@ message VarType {
   optional TensorDesc string = 8;
   optional TensorDesc strings = 9;
   optional TensorDesc vocab = 10;
+  optional TensorDesc sparse_coo = 11;
 }

 message VarDesc {

paddle/fluid/framework/infershape_utils.cc

Lines changed: 45 additions & 1 deletion
@@ -101,6 +101,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
     });
   }

+  bool IsSparseCooTensorInput(const std::string& name) const override {
+    auto var_type = ctx_.GetInputVarType(name);
+    return var_type == proto::VarType::SPARSE_COO;
+  }
+
   bool IsDenseTensorOutput(const std::string& name) const override {
     auto var_types = ctx_.GetOutputsVarType(name);
     return std::all_of(var_types.begin(),

@@ -145,6 +150,26 @@ int64_t CompatMetaTensor::numel() const {
   }
 }

+bool CompatMetaTensor::is_dense() const {
+  if (is_runtime_) {
+    auto* var = PADDLE_GET_CONST(Variable*, var_);
+    return var->IsType<phi::DenseTensor>();
+  } else {
+    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
+    return var->GetType() == proto::VarType::LOD_TENSOR;
+  }
+}
+
+bool CompatMetaTensor::is_tensor_array() const {
+  if (is_runtime_) {
+    auto* var = PADDLE_GET_CONST(Variable*, var_);
+    return var->IsType<framework::LoDTensorArray>();
+  } else {
+    auto* var = PADDLE_GET_CONST(VarDesc*, var_);
+    return var->GetType() == proto::VarType::LOD_TENSOR_ARRAY;
+  }
+}
+
 DDim CompatMetaTensor::dims() const {
   ValidCheck(*this);
   if (is_runtime_) {

@@ -153,6 +178,8 @@ DDim CompatMetaTensor::dims() const {
       return var->Get<phi::DenseTensor>().dims();
     } else if (var->IsType<phi::SelectedRows>()) {
       return var->Get<phi::SelectedRows>().dims();
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      return var->Get<phi::SparseCooTensor>().dims();
     } else if (var->IsType<framework::LoDTensorArray>()) {
       // use tensor array size as dims
       auto& tensor_array = var->Get<framework::LoDTensorArray>();

@@ -178,6 +205,8 @@ phi::DataType CompatMetaTensor::dtype() const {
       return var->Get<phi::DenseTensor>().dtype();
     } else if (var->IsType<phi::SelectedRows>()) {
       return var->Get<phi::SelectedRows>().dtype();
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      return var->Get<phi::SparseCooTensor>().dtype();
     } else if (var->IsType<framework::LoDTensorArray>()) {
       // NOTE(chenweihang): do nothing
       // Unsupported get dtype from LoDTensorArray now

@@ -200,6 +229,8 @@ DataLayout CompatMetaTensor::layout() const {
       return var->Get<phi::DenseTensor>().layout();
     } else if (var->IsType<phi::SelectedRows>()) {
       return var->Get<phi::SelectedRows>().layout();
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      return var->Get<phi::SparseCooTensor>().layout();
     } else if (var->IsType<framework::LoDTensorArray>()) {
       // NOTE(chenweihang): do nothing
       // Unsupported get layout from LoDTensorArray now

@@ -226,6 +257,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
     } else if (var->IsType<phi::SelectedRows>()) {
       auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
       phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
     } else if (var->IsType<framework::LoDTensorArray>()) {
       auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
       // Note: Here I want enforce `tensor_array->size() == 0UL`, because

@@ -257,6 +291,9 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
     } else if (var->IsType<phi::SelectedRows>()) {
       auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
       phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
     } else if (var->IsType<framework::LoDTensorArray>()) {
       // NOTE(chenweihang): do nothing
       // Unsupported set dtype for LoDTensorArray now

@@ -280,6 +317,9 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
     } else if (var->IsType<phi::SelectedRows>()) {
       auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
       phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      auto* tensor = var->GetMutable<phi::SparseCooTensor>();
+      phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
     } else if (var->IsType<framework::LoDTensorArray>()) {
       // NOTE(chenweihang): do nothing
       // Unsupported set dtype for LoDTensorArray now

@@ -299,7 +339,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
   ValidCheck(meta_tensor);
   if (is_runtime_) {
     auto* var = PADDLE_GET(Variable*, var_);
-    if (var->IsType<phi::DenseTensor>()) {
+    if (var->IsType<phi::DenseTensor>() && meta_tensor.is_dense()) {
       auto* tensor = var->GetMutable<phi::DenseTensor>();
       phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
           static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();

@@ -309,6 +349,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
     }
   } else {
     auto* var = PADDLE_GET(VarDesc*, var_);
+    if (!meta_tensor.is_dense() && !meta_tensor.is_tensor_array()) {
+      VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray.";
+      return;
+    }
     var->SetLoDLevel(
         static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
   }
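
The share_lod changes above make LoD propagation conditional on both sides being dense. A hedged sketch of the kind of guard this enables in InferMeta-style code follows; the helper name ShareLoDIfDense is an assumption, not part of the commit.

// Sketch: only propagate LoD when both metatensors describe dense variables,
// mirroring the checks that share_lod now performs internally.
void ShareLoDIfDense(const phi::MetaTensor& x, phi::MetaTensor* out) {
  if (x.is_dense() && out->is_dense()) {
    out->share_lod(x);
  }
}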

paddle/fluid/framework/infershape_utils.h

Lines changed: 3 additions & 0 deletions
@@ -59,6 +59,9 @@ class CompatMetaTensor : public phi::MetaTensor {

   bool initialized() const override { return initialized_; };

+  bool is_tensor_array() const;
+  bool is_dense() const;
+
   operator unspecified_bool_type() const override {
     return initialized_ ? unspecified_bool_true : 0;
   }

paddle/fluid/framework/operator.cc

Lines changed: 40 additions & 0 deletions
@@ -2382,6 +2382,17 @@ void OperatorWithKernel::ParseInputDataType(
       t = &var->Get<LoDTensor>();
     } else if (var->IsType<phi::SelectedRows>()) {
       t = &(var->Get<phi::SelectedRows>().value());
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
+      PADDLE_ENFORCE_EQ(
+          sp_t->initialized(),
+          true,
+          platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
+                                            "contains uninitialized Tensor.",
+                                            Type(),
+                                            name));
+      *data_type = paddle::framework::TransToProtoVarType(sp_t->dtype());
+      return;
     } else if (var->IsType<LoDTensorArray>()) {
       auto t_arr = &var->Get<LoDTensorArray>();
       for (size_t j = 0; j < t_arr->size(); j++) {

@@ -2419,6 +2430,29 @@ void OperatorWithKernel::ParseMultiInputDataType(
       t = &var->Get<LoDTensor>();
     } else if (var->IsType<phi::SelectedRows>()) {
       t = &(var->Get<phi::SelectedRows>().value());
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
+      PADDLE_ENFORCE_EQ(
+          sp_t->initialized(),
+          true,
+          platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
+                                            "contains uninitialized Tensor.",
+                                            Type(),
+                                            name));
+      proto::VarType::Type tmp =
+          paddle::framework::TransToProtoVarType(sp_t->dtype());
+      PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
+                     platform::errors::InvalidArgument(
+                         "The DataType of %s Op's duplicable or different "
+                         "slot Variable %s must be "
+                         "consistent or reigster GetExpectedKernelType. The "
+                         "current variable type is (%s), but the "
+                         "previous variable type is (%s).",
+                         Type(),
+                         name,
+                         DataTypeToString(tmp),
+                         DataTypeToString(*data_type)));
+      *data_type = tmp;
     } else if (var->IsType<LoDTensorArray>()) {
       auto t_arr = &var->Get<LoDTensorArray>();
       for (size_t j = 0; j < t_arr->size(); j++) {

@@ -2663,6 +2697,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
     } else if (var->IsType<phi::SelectedRows>()) {
       tensor_in = &(var->Get<phi::SelectedRows>());
       phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
+    } else if (var->IsType<phi::SparseCooTensor>()) {
+      tensor_in = &(var->Get<phi::SparseCooTensor>());
+      phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
     } else if (var->IsType<framework::LoDTensorArray>()) {
       need_prepare_phi_data_ = true;
       tensor_in = &(var->Get<framework::LoDTensorArray>());

@@ -2708,6 +2745,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
     } else if (var->template IsType<phi::SelectedRows>()) {
       tensor_out = var->template GetMutable<phi::SelectedRows>();
       phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
+    } else if (var->template IsType<phi::SparseCooTensor>()) {
+      tensor_out = var->template GetMutable<phi::SparseCooTensor>();
+      phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
     } else if (var->template IsType<framework::LoDTensorArray>()) {
       tensor_out = var->template GetMutable<framework::LoDTensorArray>();
       // Note: If the input LoDTensorArray size is 0, the output

paddle/fluid/framework/operator.h

Lines changed: 5 additions & 0 deletions
@@ -524,6 +524,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
     });
   }

+  bool IsSparseCooTensorInput(const std::string& name) const override {
+    const auto* var = ctx_.InputVar(name);
+    return var->IsType<phi::SparseCooTensor>();
+  }
+
   bool IsDenseTensorOutput(const std::string& name) const override {
     auto vars = ctx_.MultiOutputVar(name);
     return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {

paddle/fluid/framework/tensor.h

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"

 namespace paddle {
 namespace framework {

paddle/fluid/framework/var_desc.cc

Lines changed: 4 additions & 0 deletions
@@ -237,6 +237,8 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
       return desc_.type().strings();
     case proto::VarType::VOCAB:
       return desc_.type().vocab();
+    case proto::VarType::SPARSE_COO:
+      return desc_.type().sparse_coo();
     default:
       PADDLE_THROW(platform::errors::Unavailable(
           "Getting 'tensor_desc' is not supported by the %s type variable.",

@@ -284,6 +286,8 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
       return desc_.mutable_type()->mutable_strings();
     case proto::VarType::VOCAB:
       return desc_.mutable_type()->mutable_vocab();
+    case proto::VarType::SPARSE_COO:
+      return desc_.mutable_type()->mutable_sparse_coo();
     default:
       PADDLE_THROW(
           platform::errors::Unavailable("Getting 'mutable_tensor_desc' is not "
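
With the SPARSE_COO cases added above, a compile-time VarDesc of sparse type routes its shape and dtype through the new sparse_coo TensorDesc. A hedged sketch of such usage follows; it is illustrative only, and the variable name is made up.

// Illustrative sketch, not taken from this commit.
#include "paddle/fluid/framework/var_desc.h"

void DescribeSparseVar() {
  paddle::framework::VarDesc desc("sparse_x");  // hypothetical variable name
  desc.SetType(paddle::framework::proto::VarType::SPARSE_COO);
  desc.SetShape({4, 4});       // stored via mutable_tensor_desc() -> sparse_coo
  desc.SetDataType(paddle::framework::proto::VarType::FP32);
  auto shape = desc.GetShape();  // read back through tensor_desc()
  (void)shape;
}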

paddle/fluid/framework/var_type.h

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,7 @@ inline proto::VarType::Type ToVarType(int type) {
   switch (type) {
     case proto::VarType::LOD_TENSOR:
     case proto::VarType::SELECTED_ROWS:
+    case proto::VarType::SPARSE_COO:
     case proto::VarType::LOD_RANK_TABLE:
     case proto::VarType::LOD_TENSOR_ARRAY:
     case proto::VarType::FETCH_LIST:

@@ -59,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
     case proto::VarType::SELECTED_ROWS:
       visitor(var.Get<phi::SelectedRows>());
       return;
+    case proto::VarType::SPARSE_COO:
+      visitor(var.Get<phi::SparseCooTensor>());
+      return;
     case proto::VarType::READER:
       visitor(var.Get<ReaderHolder>());
       return;
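
Since VisitVarType now dispatches SPARSE_COO variables, any visitor passed to it must be callable with a phi::SparseCooTensor. A minimal sketch of a sparse-aware visitor, assumed for illustration and not taken from this commit:

// Hypothetical visitor: logs the number of non-zeros for sparse variables
// and ignores every other variable type.
struct SparseAwareVisitor {
  void operator()(const phi::SparseCooTensor& t) const {
    VLOG(3) << "sparse coo tensor, nnz = " << t.nnz();
  }
  template <typename T>
  void operator()(const T&) const {}
};

// Usage: paddle::framework::VisitVarType(var, SparseAwareVisitor());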
