Skip to content

Commit 47b5b94

Browse files
committed
Fix
1 parent 2275419 commit 47b5b94

File tree

4 files changed

+71
-25
lines changed

4 files changed

+71
-25
lines changed

paddle/fluid/framework/type_info.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ bool TypeInfoTraits<BaseT, DerivedT>::classof(const BaseT* obj) {
3939
}
4040

4141
template class TypeInfoTraits<phi::TensorBase, paddle::framework::RawTensor>;
42-
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
4342
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
4443
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
4544
template class TypeInfoTraits<phi::TensorBase, paddle::primitive::LazyTensor>;

paddle/fluid/operators/controlflow/feed_op.cc

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,6 @@ const framework::FeedType& CheckAndGetFeedItem(const phi::ExtendedTensor& x,
5353
return feed_list->at(static_cast<size_t>(col));
5454
}
5555

56-
template <typename Context>
57-
void FeedDenseTensorKernel(const Context& dev_ctx,
58-
const phi::ExtendedTensor& x,
59-
int col,
60-
phi::DenseTensor* out) {
61-
PADDLE_ENFORCE_NOT_NULL(
62-
out,
63-
common::errors::NotFound(
64-
"Output cannot be found in scope for operator 'Feed'"));
65-
const auto& feed_item = CheckAndGetFeedItem(x, col);
66-
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
67-
const auto& place = dev_ctx.GetPlace();
68-
if (phi::is_same_place(in_tensor.place(), place)) {
69-
out->ShareDataWith(in_tensor);
70-
} else {
71-
phi::Copy(dev_ctx, in_tensor, place, false, out);
72-
}
73-
74-
out->set_lod(in_tensor.lod());
75-
}
76-
7756
class FeedOp : public framework::OperatorWithKernel {
7857
using framework::OperatorWithKernel::OperatorWithKernel;
7958

@@ -164,6 +143,3 @@ REGISTER_OPERATOR(
164143
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
165144
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
166145
paddle::operators::FeedOpInfoMaker);
167-
168-
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(
169-
feed, ALL_LAYOUT, paddle::operators::FeedDenseTensorKernel) {}

paddle/phi/core/utils/type_info.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#include <string>
1616

17+
#include "paddle/fluid/framework/feed_fetch_type.h"
1718
#include "paddle/phi/backends/cpu/cpu_context.h"
1819
#include "paddle/phi/backends/custom/custom_context.h"
1920
#include "paddle/phi/backends/gpu/gpu_context.h"
@@ -51,6 +52,7 @@ template class TypeInfoTraits<phi::TensorBase, SparseCsrTensor>;
5152
template class TypeInfoTraits<phi::TensorBase, StringTensor>;
5253
template class TypeInfoTraits<phi::TensorBase, TensorArray>;
5354
template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
55+
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
5456
template class TypeInfoTraits<phi::TensorBase, Vocab>;
5557
template class TypeInfoTraits<phi::TensorBase, Strings>;
5658

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/core/framework/feed_fetch_type.h"
16+
#include "paddle/phi/core/kernel_registry.h"
17+
#include "paddle/phi/core/tensor_utils.h"
18+
19+
namespace phi {
20+
21+
const phi::FeedType& CheckAndGetFeedItem(const phi::ExtendedTensor& x,
22+
int col) {
23+
PADDLE_ENFORCE_GE(col,
24+
0,
25+
common::errors::InvalidArgument(
26+
"Expected the column index (the attribute 'col' of "
27+
"operator 'Feed') of current feeding variable to be "
28+
"no less than 0. But received column index = %d.",
29+
col));
30+
auto feed_list = static_cast<const phi::FeedList*>(&x);
31+
PADDLE_ENFORCE_LT(
32+
static_cast<size_t>(col),
33+
feed_list->size(),
34+
common::errors::InvalidArgument(
35+
"The column index of current feeding variable is expected to be "
36+
"less than the length of feeding list. But received column index = "
37+
"%d, the length of feeding list = %d",
38+
col,
39+
feed_list->size()));
40+
41+
return feed_list->at(static_cast<size_t>(col));
42+
}
43+
44+
template <typename Context>
45+
void FeedDenseTensorKernel(const Context& dev_ctx,
46+
const phi::ExtendedTensor& x,
47+
int col,
48+
phi::DenseTensor* out) {
49+
PADDLE_ENFORCE_NOT_NULL(
50+
out,
51+
common::errors::NotFound(
52+
"Output cannot be found in scope for operator 'Feed'"));
53+
const auto& feed_item = CheckAndGetFeedItem(x, col);
54+
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
55+
const auto& place = dev_ctx.GetPlace();
56+
if (phi::is_same_place(in_tensor.place(), place)) {
57+
out->ShareDataWith(in_tensor);
58+
} else {
59+
phi::Copy(dev_ctx, in_tensor, place, false, out);
60+
}
61+
62+
out->set_lod(in_tensor.lod());
63+
}
64+
65+
} // namespace phi
66+
67+
PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(feed,
68+
ALL_LAYOUT,
69+
phi::FeedDenseTensorKernel) {}

0 commit comments

Comments
 (0)