18 changes: 1 addition & 17 deletions paddle/fluid/framework/feed_fetch_method.cc
@@ -69,29 +69,13 @@ void SetFeedVariable(Scope* scope,
feed_inputs.resize(index + 1);
}
// shared data with input tensor
auto& val = PADDLE_GET(phi::DenseTensor, feed_inputs[index]);
auto& val = feed_inputs[index];
val.ShareDataWith(input);
// set lod
val.set_lod(input.lod());
}
}

void SetFeedVariable(Scope* scope,
const std::vector<std::string>& input,
const std::string& var_name,
size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
// be created.
VLOG(3) << "SetFeedStringVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs = *(g_feed_value->GetMutable<FeedList>());
if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1);
}
// shared data with input tensor
feed_inputs[index] = Strings(input);
}

FetchType& GetFetchVariable(const Scope& scope,
const std::string& var_name,
size_t index) {
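
With FeedType reduced to a plain phi::DenseTensor (see the paddle/phi/core/framework/feed_fetch_type.h change below), the surviving SetFeedVariable overload writes straight into the feed list with no variant unwrapping. A minimal caller-side sketch, not part of the patch; the variable name "feed" and the index 0 are illustrative assumptions:

// Sketch only: feeding one DenseTensor through the remaining overload.
#include "paddle/fluid/framework/feed_fetch_method.h"

void FeedOneTensor(paddle::framework::Scope* scope,
                   const phi::DenseTensor& input) {
  // Creates (or reuses) the FeedList variable named "feed" and shares the
  // input's data into slot 0, as the function body in the diff above does.
  paddle::framework::SetFeedVariable(scope, input, "feed", /*index=*/0);
}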
5 changes: 0 additions & 5 deletions paddle/fluid/framework/feed_fetch_method.h
@@ -38,11 +38,6 @@ void SetFeedVariable(Scope* scope,
const std::string& var_name,
size_t index);

void SetFeedVariable(Scope* scope,
const std::vector<std::string>& input,
const std::string& var_name,
size_t index);

FetchType& GetFetchVariable(const Scope& scope,
const std::string& var_name,
size_t index);
7 changes: 0 additions & 7 deletions paddle/fluid/framework/feed_fetch_type.h
@@ -32,13 +32,6 @@ inline bool data_is_lod_tensor_array(const FetchType &data) {
return false;
}

inline bool data_is_string_tensor(const FeedType &data) {
if (data.type() == typeid(Strings)) {
return true;
}
return false;
}

inline bool data_is_sparse_coo_tensor(const FetchType &data) {
if (data.type() == typeid(phi::SparseCooTensor)) {
return true;
1 change: 0 additions & 1 deletion paddle/fluid/framework/type_info.cc
@@ -39,7 +39,6 @@ bool TypeInfoTraits<BaseT, DerivedT>::classof(const BaseT* obj) {
}

template class TypeInfoTraits<phi::TensorBase, paddle::framework::RawTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::primitive::LazyTensor>;
61 changes: 12 additions & 49 deletions paddle/fluid/operators/controlflow/feed_op.cc
@@ -53,27 +53,6 @@ const framework::FeedType& CheckAndGetFeedItem(const phi::ExtendedTensor& x,
return feed_list->at(static_cast<size_t>(col));
}

template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
out,
common::errors::NotFound(
"Output cannot be found in scope for operator 'Feed'"));
const auto& feed_item = CheckAndGetFeedItem(x, col);
const auto& in_tensor = paddle::get<phi::DenseTensor>(feed_item);
const auto& place = dev_ctx.GetPlace();
if (phi::is_same_place(in_tensor.place(), place)) {
out->ShareDataWith(in_tensor);
} else {
phi::Copy(dev_ctx, in_tensor, place, false, out);
}

out->set_lod(in_tensor.lod());
}

class FeedOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

@@ -90,23 +69,18 @@ class FeedOp : public framework::OperatorWithKernel {
int col = ctx->Attrs().Get<int>("col");
const auto& feed_item = CheckAndGetFeedItem(x, col);

if (feed_item.index() == 0) { // DenseTensor
auto& feed_tensor = PADDLE_GET_CONST(phi::DenseTensor, feed_item);
phi::DenseTensor* out_tensor = out_var->GetMutable<phi::DenseTensor>();
phi::DenseTensorMeta meta = out_tensor->meta();
meta.dims = feed_tensor.dims();
meta.dtype = feed_tensor.dtype();
meta.layout = feed_tensor.layout();
meta.lod = feed_tensor.lod();
meta.strides = feed_tensor.strides();
if (meta.strides.size() == -1) {
meta.strides = meta.calc_strides(meta.dims);
}
out_tensor->set_meta(meta);
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Only support DenseTensor for feed op now."));
auto& feed_tensor = feed_item;
phi::DenseTensor* out_tensor = out_var->GetMutable<phi::DenseTensor>();
phi::DenseTensorMeta meta = out_tensor->meta();
meta.dims = feed_tensor.dims();
meta.dtype = feed_tensor.dtype();
meta.layout = feed_tensor.layout();
meta.lod = feed_tensor.lod();
meta.strides = feed_tensor.strides();
if (meta.strides.size() == -1) {
meta.strides = meta.calc_strides(meta.dims);
}
out_tensor->set_meta(meta);
}
}

@@ -119,15 +93,7 @@ class FeedOp : public framework::OperatorWithKernel {
auto& feed_item = x[col];

framework::proto::VarType::Type expected_data_type;
if (feed_item.index() == 0) { // DenseTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::DenseTensor, feed_item).dtype());
} else if (feed_item.index() == 2) { // SparseCooTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item).dtype());
} else { // Strings
expected_data_type = framework::proto::VarType::FP32;
}
expected_data_type = framework::TransToProtoVarType(feed_item.dtype());

return phi::KernelKey(expected_data_type, ctx.GetPlace());
}
@@ -164,6 +130,3 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::FeedOpInfoMaker);

PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(
feed, ALL_LAYOUT, paddle::operators::FeedDenseTensorKernel) {}
6 changes: 0 additions & 6 deletions paddle/fluid/pybind/pybind.cc
@@ -2474,12 +2474,6 @@ All parameter, weight, gradient are variables in Paddle.
const phi::DenseTensor &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("set_feed_variable",
static_cast<void (*)( // NOLINT
Scope *,
const std::vector<std::string> &,
const std::string &,
size_t)>(&framework::SetFeedVariable));
m.def("get_fetch_variable",
[](const Scope &scope,
const std::string &var_name,
3 changes: 1 addition & 2 deletions paddle/phi/core/framework/feed_fetch_type.h
@@ -21,8 +21,7 @@
#include "paddle/phi/core/vocab/string_array.h"

namespace phi {
using FeedType =
paddle::variant<phi::DenseTensor, phi::Strings, phi::SparseCooTensor>;
using FeedType = phi::DenseTensor;
using FetchType = paddle::variant<phi::DenseTensor,
phi::TensorArray,
phi::Vocab,
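
Because FeedType is now an alias for phi::DenseTensor, code that previously dispatched on the variant index (PADDLE_GET, paddle::get, feed_item.index()) can read a feed slot directly. A hedged sketch; GetFeedTensor is an illustrative helper, not an API introduced by this patch:

// Sketch only: FeedList elements are DenseTensors, so no variant access is needed.
#include "paddle/phi/core/framework/feed_fetch_type.h"

const phi::DenseTensor& GetFeedTensor(const phi::FeedList& feed_list, size_t col) {
  // at() is the same accessor CheckAndGetFeedItem uses in the feed kernel below.
  return feed_list.at(col);
}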
6 changes: 6 additions & 0 deletions paddle/phi/core/kernel_registry.cc
@@ -103,6 +103,12 @@ void SetKernelArgsDef(const std::vector<std::type_index>& args_type,
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type ==
std::type_index(typeid(const phi::FeedList&))) { // NOLINT
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
} else if (arg_type == std::type_index(typeid(
const paddle::optional<Strings>&))) { // NOLINT
args_def->AppendInput(default_key.backend(),
2 changes: 2 additions & 0 deletions paddle/phi/core/kernel_utils.h
@@ -21,6 +21,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/extended_tensor.h"
#include "paddle/phi/core/framework/feed_fetch_type.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
@@ -344,6 +345,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {

PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(phi::Strings);
PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(phi::Strings);
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(phi::FeedList);

/* Attribute Helpers */

2 changes: 2 additions & 0 deletions paddle/phi/core/utils/type_info.cc
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/framework/feed_fetch_type.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
@@ -53,6 +54,7 @@ template class TypeInfoTraits<phi::TensorBase, TensorArray>;
template class TypeInfoTraits<phi::TensorBase, phi::distributed::DistTensor>;
template class TypeInfoTraits<phi::TensorBase, Vocab>;
template class TypeInfoTraits<phi::TensorBase, Strings>;
template class TypeInfoTraits<phi::TensorBase, FeedList>;

template class TypeInfoTraits<phi::DeviceContext, CPUContext>;
template class TypeInfoTraits<phi::DeviceContext, CustomContext>;
69 changes: 69 additions & 0 deletions paddle/phi/kernels/cpu/feed_kernel.cc
@@ -0,0 +1,69 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/framework/feed_fetch_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {

const phi::FeedType& CheckAndGetFeedItem(const phi::FeedList* x, int col) {
PADDLE_ENFORCE_GE(col,
0,
common::errors::InvalidArgument(
"Expected the column index (the attribute 'col' of "
"operator 'Feed') of current feeding variable to be "
"no less than 0. But received column index = %d.",
col));
const auto feed_list = x;
PADDLE_ENFORCE_LT(
static_cast<size_t>(col),
feed_list->size(),
common::errors::InvalidArgument(
"The column index of current feeding variable is expected to be "
"less than the length of feeding list. But received column index = "
"%d, the length of feeding list = %d",
col,
feed_list->size()));

return feed_list->at(static_cast<size_t>(col));
}

template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
out,
common::errors::NotFound(
"Output cannot be found in scope for operator 'Feed'"));
const auto& feed_item =
CheckAndGetFeedItem(reinterpret_cast<const phi::FeedList*>(&x), col);
const auto& in_tensor = static_cast<DenseTensor>(feed_item);
const auto& place = dev_ctx.GetPlace();
if (phi::is_same_place(in_tensor.place(), place)) {
out->ShareDataWith(in_tensor);
} else {
phi::Copy(dev_ctx, in_tensor, place, false, out);
}

out->set_lod(in_tensor.lod());
}

} // namespace phi

PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(feed,
ALL_LAYOUT,
phi::FeedDenseTensorKernel) {}
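
For orientation, a hedged sketch of what exercising the new kernel amounts to. The direct call is illustrative only (FeedDenseTensorKernel is defined in this .cc and is normally reached through the kernel registered as "feed"), and it assumes phi::FeedList is the ExtendedTensor-derived container with the resize()/operator[] interface that SetFeedVariable and CheckAndGetFeedItem rely on:

// Sketch under the stated assumptions; all names here are illustrative.
#include "paddle/phi/backends/cpu/cpu_context.h"

void FeedColumnZero(const phi::CPUContext& dev_ctx, const phi::DenseTensor& input) {
  phi::FeedList feeds;
  feeds.resize(1);
  feeds[0].ShareDataWith(input);  // FeedType == phi::DenseTensor
  phi::DenseTensor out;
  // Shares the column-0 tensor when the places match, otherwise copies it,
  // per the phi::is_same_place branch in the kernel above.
  phi::FeedDenseTensorKernel<phi::CPUContext>(dev_ctx, feeds, /*col=*/0, &out);
}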