@@ -101,6 +101,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
101101 });
102102 }
103103
104+ bool IsSparseCooTensorInput (const std::string& name) const override {
105+ auto var_type = ctx_.GetInputVarType (name);
106+ return var_type == proto::VarType::SPARSE_COO;
107+ }
108+
104109 bool IsDenseTensorOutput (const std::string& name) const override {
105110 auto var_types = ctx_.GetOutputsVarType (name);
106111 return std::all_of (var_types.begin (),
@@ -145,6 +150,26 @@ int64_t CompatMetaTensor::numel() const {
145150 }
146151}
147152
153+ bool CompatMetaTensor::is_dense () const {
154+ if (is_runtime_) {
155+ auto * var = PADDLE_GET_CONST (Variable*, var_);
156+ return var->IsType <phi::DenseTensor>();
157+ } else {
158+ auto * var = PADDLE_GET_CONST (VarDesc*, var_);
159+ return var->GetType () == proto::VarType::LOD_TENSOR;
160+ }
161+ }
162+
163+ bool CompatMetaTensor::is_tensor_array () const {
164+ if (is_runtime_) {
165+ auto * var = PADDLE_GET_CONST (Variable*, var_);
166+ return var->IsType <framework::LoDTensorArray>();
167+ } else {
168+ auto * var = PADDLE_GET_CONST (VarDesc*, var_);
169+ return var->GetType () == proto::VarType::LOD_TENSOR_ARRAY;
170+ }
171+ }
172+
148173DDim CompatMetaTensor::dims () const {
149174 ValidCheck (*this );
150175 if (is_runtime_) {
@@ -153,6 +178,8 @@ DDim CompatMetaTensor::dims() const {
153178 return var->Get <phi::DenseTensor>().dims ();
154179 } else if (var->IsType <phi::SelectedRows>()) {
155180 return var->Get <phi::SelectedRows>().dims ();
181+ } else if (var->IsType <phi::SparseCooTensor>()) {
182+ return var->Get <phi::SparseCooTensor>().dims ();
156183 } else if (var->IsType <framework::LoDTensorArray>()) {
157184 // use tensor array size as dims
158185 auto & tensor_array = var->Get <framework::LoDTensorArray>();
@@ -178,6 +205,8 @@ phi::DataType CompatMetaTensor::dtype() const {
178205 return var->Get <phi::DenseTensor>().dtype ();
179206 } else if (var->IsType <phi::SelectedRows>()) {
180207 return var->Get <phi::SelectedRows>().dtype ();
208+ } else if (var->IsType <phi::SparseCooTensor>()) {
209+ return var->Get <phi::SparseCooTensor>().dtype ();
181210 } else if (var->IsType <framework::LoDTensorArray>()) {
182211 // NOTE(chenweihang): do nothing
183212 // Unsupported get dtype from LoDTensorArray now
@@ -200,6 +229,8 @@ DataLayout CompatMetaTensor::layout() const {
200229 return var->Get <phi::DenseTensor>().layout ();
201230 } else if (var->IsType <phi::SelectedRows>()) {
202231 return var->Get <phi::SelectedRows>().layout ();
232+ } else if (var->IsType <phi::SparseCooTensor>()) {
233+ return var->Get <phi::SparseCooTensor>().layout ();
203234 } else if (var->IsType <framework::LoDTensorArray>()) {
204235 // NOTE(chenweihang): do nothing
205236 // Unsupported get layout from LoDTensorArray now
@@ -226,6 +257,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
226257 } else if (var->IsType <phi::SelectedRows>()) {
227258 auto * tensor = var->GetMutable <phi::SelectedRows>()->mutable_value ();
228259 phi::DenseTensorUtils::GetMutableMeta (tensor)->dims = dims;
260+ } else if (var->IsType <phi::SparseCooTensor>()) {
261+ auto * tensor = var->GetMutable <phi::SparseCooTensor>();
262+ phi::DenseTensorUtils::GetMutableMeta (tensor)->dims = dims;
229263 } else if (var->IsType <framework::LoDTensorArray>()) {
230264 auto * tensor_array = var->GetMutable <framework::LoDTensorArray>();
231265 // Note: Here I want enforce `tensor_array->size() == 0UL`, because
@@ -257,6 +291,9 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
257291 } else if (var->IsType <phi::SelectedRows>()) {
258292 auto * tensor = var->GetMutable <phi::SelectedRows>()->mutable_value ();
259293 phi::DenseTensorUtils::GetMutableMeta (tensor)->dtype = dtype;
294+ } else if (var->IsType <phi::SparseCooTensor>()) {
295+ auto * tensor = var->GetMutable <phi::SparseCooTensor>();
296+ phi::DenseTensorUtils::GetMutableMeta (tensor)->dtype = dtype;
260297 } else if (var->IsType <framework::LoDTensorArray>()) {
261298 // NOTE(chenweihang): do nothing
262299 // Unsupported set dtype for LoDTensorArray now
@@ -280,6 +317,9 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
280317 } else if (var->IsType <phi::SelectedRows>()) {
281318 auto * tensor = var->GetMutable <phi::SelectedRows>()->mutable_value ();
282319 phi::DenseTensorUtils::GetMutableMeta (tensor)->layout = layout;
320+ } else if (var->IsType <phi::SparseCooTensor>()) {
321+ auto * tensor = var->GetMutable <phi::SparseCooTensor>();
322+ phi::DenseTensorUtils::GetMutableMeta (tensor)->layout = layout;
283323 } else if (var->IsType <framework::LoDTensorArray>()) {
284324 // NOTE(chenweihang): do nothing
285325 // Unsupported set dtype for LoDTensorArray now
@@ -299,7 +339,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
299339 ValidCheck (meta_tensor);
300340 if (is_runtime_) {
301341 auto * var = PADDLE_GET (Variable*, var_);
302- if (var->IsType <phi::DenseTensor>()) {
342+ if (var->IsType <phi::DenseTensor>() && meta_tensor. is_dense () ) {
303343 auto * tensor = var->GetMutable <phi::DenseTensor>();
304344 phi::DenseTensorUtils::GetMutableMeta (tensor)->lod =
305345 static_cast <const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD ();
@@ -309,6 +349,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
309349 }
310350 } else {
311351 auto * var = PADDLE_GET (VarDesc*, var_);
352+ if (!meta_tensor.is_dense () && !meta_tensor.is_tensor_array ()) {
353+ VLOG (3 ) << " input metatensor is not LoDTensor or LoDTensorArray." ;
354+ return ;
355+ }
312356 var->SetLoDLevel (
313357 static_cast <const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD ());
314358 }
0 commit comments