Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 84 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

头文件已经删除


Expand All @@ -33,7 +35,9 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,

while (col < size) {
T o = dout[col];
dx[col] = o / y[col];
if (dx != nullptr) {
dx[col] = o / y[col];
}
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
}
Expand All @@ -55,7 +59,9 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种写法可以修改成为 grid_stride的写法,见链接:https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
Expand All @@ -77,12 +83,87 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_div_grad(const framework::ExecutionContext& ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

block_size 定义了但没有被使用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经删掉

// dx
if (dx != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutable_data的结果不必传给指针(下文没用到指针),下同

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
if (dx->dims() == dout->dims()) {
// dx = dout/y
ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ctx, dout, y, axis, DivGradFunctor<T>(), dx);
} else {
framework::Tensor tmp_dx;
tmp_dx.Resize(dout->dims());

ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

纯GPU代码就不要调用这个接口了,这个接口是用于同时需要支持CPU 和 GPU计算的时候才用的,纯粹GPU的代码还是走LaunchElementwiseCudaKernel 更直观

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ctx, dout, y, axis, DivGradFunctor<T>(), &tmp_dx);

std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
tmp_dx, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
// dy = - dout * out / y
auto size = dy->numel();
dim3 grid_size = dim3(
(size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
SimpleElemwiseDivGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
nullptr, dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());

std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&tmp_dy};

const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
dev_ctx, ins, &outs, axis, DivGradYFunctor<T>());

std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
tmp_dy, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
Expand Down
35 changes: 28 additions & 7 deletions paddle/fluid/operators/elementwise/elementwise_div_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ struct DivDoubleDY {
}
};

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");

ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
Expand All @@ -116,13 +131,21 @@ elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DeviceContext, typename T>
// cuda definition
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
Expand All @@ -146,14 +169,12 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");

if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DefaultElementwiseDivGrad已经包括这个分支了,可以删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
DivGradDY<T>());
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default也改个名字吧,比如改成Common,或者其他更好的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续会统一修改

dy);
}
}
};
Expand Down
55 changes: 55 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
Expand Down Expand Up @@ -113,6 +114,60 @@ struct MinFunctor {
}
};

// Float mul grad
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};

// Complex mul grad
template <typename T>
struct MulGradFunctor<paddle::platform::complex<T>> {
inline HOSTDEVICE paddle::platform::complex<T> operator()(
const paddle::platform::complex<T>& a,
const paddle::platform::complex<T>& b) const {
paddle::platform::complex<T> b_conj(b.real, -b.imag);
return a * b_conj;
}
};

// Float div grad
template <typename T>
struct DivGradFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
};

// Complex div grad
template <typename T>
struct DivGradFunctor<paddle::platform::complex<T>> {
inline HOSTDEVICE paddle::platform::complex<T> operator()(
const paddle::platform::complex<T>& a,
const paddle::platform::complex<T>& b) const {
paddle::platform::complex<T> b_conj(b.real, -b.imag);
return a / b_conj;
}
};

// Float mul and div
template <typename T>
struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
return a * b / c;
}
};

// Complex mul and div
template <typename T>
struct DivGradYFunctor<paddle::platform::complex<T>> {
inline HOSTDEVICE paddle::platform::complex<T> operator()(
const paddle::platform::complex<T>& a,
const paddle::platform::complex<T>& b,
const paddle::platform::complex<T>& c) const {
paddle::platform::complex<T> c_conj(c.real, -c.imag);
return a * b / c_conj;
}
};

// Fmax
template <typename T>
struct FMaxFunctor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ void LaunchBroadcastElementwiseCudaKernel(
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(kArity,
2,
PADDLE_ENFORCE_LE(kArity,
3,
paddle::platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
Expand Down