Skip to content

Commit 061d226

Browse files
committed
Fix
1 parent 4a071e2 commit 061d226

File tree

11 files changed

+107
-26
lines changed

11 files changed

+107
-26
lines changed

paddle/fluid/framework/type_info.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ bool TypeInfoTraits<BaseT, DerivedT>::classof(const BaseT* obj) {
3939
}
4040

4141
template class TypeInfoTraits<phi::TensorBase, paddle::framework::RawTensor>;
42-
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
4342
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
4443
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
4544
template class TypeInfoTraits<phi::TensorBase, paddle::primitive::LazyTensor>;

paddle/fluid/operators/controlflow/feed_op.cc

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,6 @@ const framework::FeedType& CheckAndGetFeedItem(const phi::ExtendedTensor& x,
5353
return feed_list->at(static_cast<size_t>(col));
5454
}
5555

56-
template <typename Context>
57-
void FeedDenseTensorKernel(const Context& dev_ctx,
58-
const phi::ExtendedTensor& x,
59-
int col,
60-
phi::DenseTensor* out) {
61-
PADDLE_ENFORCE_NOT_NULL(
62-
out,
63-
common::errors::NotFound(
64-
"Output cannot be found in scope for operator 'Feed'"));
65-
const auto& feed_item = CheckAndGetFeedItem(x, col);
66-
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
67-
const auto& place = dev_ctx.GetPlace();
68-
if (phi::is_same_place(in_tensor.place(), place)) {
69-
out->ShareDataWith(in_tensor);
70-
} else {
71-
phi::Copy(dev_ctx, in_tensor, place, false, out);
72-
}
73-
74-
out->set_lod(in_tensor.lod());
75-
}
76-
7756
class FeedOp : public framework::OperatorWithKernel {
7857
using framework::OperatorWithKernel::OperatorWithKernel;
7958

@@ -164,6 +143,3 @@ REGISTER_OPERATOR(
164143
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
165144
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
166145
paddle::operators::FeedOpInfoMaker);
167-
168-
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(
169-
feed, ALL_LAYOUT, paddle::operators::FeedDenseTensorKernel) {}

paddle/fluid/operators/generator/get_expected_kernel_func.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,4 +447,22 @@ phi::KernelKey GetMulticlassNmsExpectedKernelType(
447447
phi::CPUPlace());
448448
}
449449

450+
phi::KernelKey GetFeedExpectedKernelType(
451+
const framework::ExecutionContext& ctx,
452+
const framework::OperatorWithKernel* op_ptr) {
453+
const framework::Variable* x_var = ctx.InputVar("X");
454+
auto& x = x_var->Get<framework::FeedList>();
455+
int col = ctx.Attr<int>("col");
456+
auto& feed_item = x[col];
457+
458+
framework::proto::VarType::Type expected_data_type;
459+
if (feed_item.index() == 0) { // DenseTensor
460+
expected_data_type = framework::TransToProtoVarType(
461+
PADDLE_GET_CONST(phi::DenseTensor, feed_item).dtype());
462+
} else {
463+
expected_data_type = framework::proto::VarType::FP32;
464+
}
465+
466+
return phi::KernelKey(expected_data_type, ctx.GetPlace());
467+
}
450468
} // namespace paddle::operators

paddle/fluid/operators/generator/get_expected_kernel_func.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,9 @@ phi::KernelKey GetMulticlassNmsExpectedKernelType(
112112
const framework::ExecutionContext& ctx,
113113
const framework::OperatorWithKernel* op_ptr);
114114

115+
phi::KernelKey GetFeedExpectedKernelType(
116+
const framework::ExecutionContext& ctx,
117+
const framework::OperatorWithKernel* op_ptr);
118+
115119
} // namespace operators
116120
} // namespace paddle

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@
176176
'dgc',
177177
'dpsgd',
178178
'embedding_grad_sparse',
179+
'feed',
179180
'faster_tokenizer',
180181
'ftrl',
181182
'fused_adam_',

paddle/phi/core/utils/type_info.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include <string>
1616

17+
#include "paddle/fluid/framework/feed_fetch_type.h"
1718
#include "paddle/phi/backends/cpu/cpu_context.h"
1819
#include "paddle/phi/backends/custom/custom_context.h"
1920
#include "paddle/phi/backends/gpu/gpu_context.h"
@@ -51,6 +52,7 @@ template class TypeInfoTraits<phi::TensorBase, SparseCsrTensor>;
5152
template class TypeInfoTraits<phi::TensorBase, StringTensor>;
5253
template class TypeInfoTraits<phi::TensorBase, TensorArray>;
5354
template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
55+
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
5456
template class TypeInfoTraits<phi::TensorBase, Vocab>;
5557
template class TypeInfoTraits<phi::TensorBase, Strings>;
5658

paddle/phi/infermeta/unary.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,8 @@ void FillAnyLikeInferMeta(const MetaTensor& x,
14391439
out->share_lod(x);
14401440
}
14411441

1442+
void FeedInferMeta(MetaTensor* out) {}
1443+
14421444
void FetchBarrierInferMeta(const std::vector<const MetaTensor*>& x,
14431445
int trainer_id,
14441446
const std::vector<std::string>& endpoints,

paddle/phi/infermeta/unary.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ void FakeQuantizeAbsMaxInferMeta(const MetaTensor& x,
267267
MetaTensor* out,
268268
MetaTensor* out_scale);
269269

270+
void FeedInferMeta(MetaTensor* out);
271+
270272
void FetchBarrierInferMeta(const std::vector<const MetaTensor*>& x,
271273
int trainer_id,
272274
const std::vector<std::string>& endpoints,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/feed_fetch_type.h"
16+
#include "paddle/phi/core/kernel_registry.h"
17+
#include "paddle/phi/core/tensor_utils.h"
18+
19+
namespace phi {
20+
21+
const paddle::framework::FeedType& CheckAndGetFeedItem(
22+
const phi::ExtendedTensor& x, int col) {
23+
PADDLE_ENFORCE_GE(col,
24+
0,
25+
common::errors::InvalidArgument(
26+
"Expected the column index (the attribute 'col' of "
27+
"operator 'Feed') of current feeding variable to be "
28+
"no less than 0. But received column index = %d.",
29+
col));
30+
auto feed_list = static_cast<const paddle::framework::FeedList*>(&x);
31+
PADDLE_ENFORCE_LT(
32+
static_cast<size_t>(col),
33+
feed_list->size(),
34+
common::errors::InvalidArgument(
35+
"The column index of current feeding variable is expected to be "
36+
"less than the length of feeding list. But received column index = "
37+
"%d, the length of feeding list = %d",
38+
col,
39+
feed_list->size()));
40+
41+
return feed_list->at(static_cast<size_t>(col));
42+
}
43+
44+
template <typename Context>
45+
void FeedDenseTensorKernel(const Context& dev_ctx,
46+
const phi::ExtendedTensor& x,
47+
int col,
48+
phi::DenseTensor* out) {
49+
PADDLE_ENFORCE_NOT_NULL(
50+
out,
51+
common::errors::NotFound(
52+
"Output cannot be found in scope for operator 'Feed'"));
53+
const auto& feed_item = CheckAndGetFeedItem(x, col);
54+
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
55+
const auto& place = dev_ctx.GetPlace();
56+
if (phi::is_same_place(in_tensor.place(), place)) {
57+
out->ShareDataWith(in_tensor);
58+
} else {
59+
phi::Copy(dev_ctx, in_tensor, place, false, out);
60+
}
61+
62+
out->set_lod(in_tensor.lod());
63+
}
64+
65+
} // namespace phi
66+
67+
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(feed,
68+
ALL_LAYOUT,
69+
phi::FeedDenseTensorKernel) {}

paddle/phi/ops/yaml/inconsistent/static_ops.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,13 @@
308308
traits : paddle::dialect::ForwardOnlyTrait
309309

310310
- op : feed
311-
args : (str name, int col)
311+
args : (Tensor x, int col)
312312
output : Tensor(out)
313+
infer_meta :
314+
func : FeedInferMeta
315+
param: []
316+
kernel :
317+
func : feed
313318
interfaces : paddle::dialect::InferSymbolicShapeInterface
314319
traits: pir::ImmutableLayoutTrait
315320

0 commit comments

Comments
 (0)