Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions paddle/fluid/operators/lookup_table_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,6 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
ops::LookupTableV2GradOpNoBufferVarsInferer,
ops::LookupTableV2OpGradVarTypeInference);

REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
ops::LookupTableV2Kernel<double>,
ops::LookupTableV2Kernel<paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
lookup_table_v2_grad, ops::LookupTableV2GradKernel<float>,
ops::LookupTableV2GradKernel<double>,
ops::LookupTableV2GradKernel<paddle::platform::bfloat16>);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lookup_table_v2)
.AddCheckpoint(
Expand Down
10 changes: 0 additions & 10 deletions paddle/fluid/operators/lookup_table_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,3 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel<float>,
ops::LookupTableV2CUDAKernel<double>,
ops::LookupTableV2CUDAKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradCUDAKernel<float>,
ops::LookupTableV2GradCUDAKernel<double>,
ops::LookupTableV2GradCUDAKernel<plat::float16>);
220 changes: 220 additions & 0 deletions paddle/phi/kernels/cpu/embedding_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
struct EmbeddingGradCPUFunctor {
EmbeddingGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}

template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();

auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());

// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
auto* ids_data = ids.data();

int64_t N = table_dim[0];
int64_t D = table_dim[1];

auto* d_output_data = d_output->template data<T>();

dev_ctx_.template Alloc<T>(weight_grad_);
auto* d_table_data = weight_grad_->data<T>();

memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));

for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
PADDLE_ENFORCE_GE(
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
}
}
}
}

private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
DenseTensor* weight_grad_;
int64_t padding_idx_;
};

template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
EmbeddingGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}

template <typename T, typename Context>
struct EmbeddingSparseGradCPUFunctor {
EmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}

template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();

auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());

// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto* d_table = weight_grad_;
auto* d_output = &out_grad_;
d_table->set_rows(ids);

auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});

dev_ctx_.template Alloc<T>(d_table_value);

d_table->set_height(table_dim[0]);

auto* d_output_data = d_output->template data<T>();
auto* d_table_data = d_table_value->template data<T>();

auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}

private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
SelectedRows* weight_grad_;
int64_t padding_idx_;
};

template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个Kernel可以放到selected_rows下

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

selected rows的拆分,单独用一个pr来做

const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
EmbeddingSparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}

} // namespace phi

PD_REGISTER_KERNEL(embedding_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingGradKernel,
float,
double,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(embedding_sparse_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingSparseGradKernel,
float,
double,
phi::dtype::bfloat16) {}
114 changes: 114 additions & 0 deletions paddle/phi/kernels/cpu/embedding_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {

template <typename T, typename Context>
struct EmbeddingCPUFunctor {
EmbeddingCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}

template <typename IdT>
void apply() {
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_numel = static_cast<int64_t>(ids.size());

int64_t row_number = weight_.dims()[0];
int64_t row_width = weight_.dims()[1];

auto* table = weight_.data<T>();

dev_ctx_.template Alloc<T>(out_);
auto* output = out_->data<T>();

for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i],
row_number,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fluid.layers.embedding->paddle.nn.functional.embedding

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
PADDLE_ENFORCE_GE(
ids[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fluid.layers.embedding->paddle.nn.functional.embedding

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}

private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
DenseTensor* out_;
int64_t padding_idx_;
};

template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out) {
EmbeddingCPUFunctor<T, Context> functor(ctx, input, weight, padding_idx, out);

if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
}
}

} // namespace phi

PD_REGISTER_KERNEL(embedding,
CPU,
ALL_LAYOUT,
phi::EmbeddingKernel,
float,
double,
phi::dtype::bfloat16) {}
Loading