-
Couldn't load subscription status.
- Fork 5.9k
implementation of broadcast div backward by reduce #38044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
d3173f8
c6cef2e
9265a8d
080bf95
f0f1cf3
8c43581
b1f58dc
e07e54e
3594f6b
7adf371
560ed45
2920824
476c797
8259c34
d2f3776
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,8 @@ limitations under the License. */ | |
|
|
||
| #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" | ||
| #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" | ||
| #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" | ||
| #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" | ||
| #include "paddle/fluid/platform/complex.h" | ||
| #include "paddle/fluid/platform/float16.h" | ||
|
|
||
|
|
@@ -33,7 +35,9 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, | |
|
|
||
| while (col < size) { | ||
| T o = dout[col]; | ||
| dx[col] = o / y[col]; | ||
| if (dx != nullptr) { | ||
| dx[col] = o / y[col]; | ||
| } | ||
| dy[col] = -o * out[col] / y[col]; | ||
| col += blockDim.x * gridDim.x; | ||
| } | ||
|
|
@@ -55,7 +59,9 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>( | |
| paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag); | ||
| paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real, | ||
| -(out[col] / y[col]).imag); | ||
| dx[col] = o / y_conj; | ||
| if (dx != nullptr) { | ||
| dx[col] = o / y_conj; | ||
| } | ||
| dy[col] = -o * out_div_y_conj; | ||
| col += blockDim.x * gridDim.x; | ||
|
||
| } | ||
|
|
@@ -77,12 +83,87 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>( | |
| paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag); | ||
| paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real, | ||
| -(out[col] / y[col]).imag); | ||
| dx[col] = o / y_conj; | ||
| if (dx != nullptr) { | ||
| dx[col] = o / y_conj; | ||
| } | ||
| dy[col] = -o * out_div_y_conj; | ||
| col += blockDim.x * gridDim.x; | ||
|
||
| } | ||
| } | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
| default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
|
||
| const framework::Tensor* x, | ||
| const framework::Tensor* y, | ||
| const framework::Tensor* out, | ||
| const framework::Tensor* dout, | ||
| framework::Tensor* dx, framework::Tensor* dy) { | ||
| int axis = ctx.Attr<int>("axis"); | ||
| auto* dout_data = dout->data<T>(); | ||
| dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); | ||
|
||
| // dx | ||
| if (dx != nullptr) { | ||
| auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); | ||
|
||
| // For inplace strategy, dx will be stored in addr of dout, which makes | ||
| // the result of dy wrong. | ||
| if (dx->IsSharedBufferWith(*dout)) { | ||
| dx->clear(); | ||
| dx->mutable_data<T>(x->dims(), ctx.GetPlace()); | ||
| } | ||
| if (dx->dims() == dout->dims()) { | ||
| // dx = dout/y | ||
| ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>( | ||
|
||
| ctx, dout, y, axis, DivGradFunctor<T>(), dx); | ||
| } else { | ||
| framework::Tensor tmp_dx; | ||
| tmp_dx.Resize(dout->dims()); | ||
|
|
||
| ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>( | ||
|
||
| ctx, dout, y, axis, DivGradFunctor<T>(), &tmp_dx); | ||
|
|
||
| std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis); | ||
| gpuStream_t stream = ctx.cuda_device_context().stream(); | ||
| TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( | ||
| tmp_dx, dx, kps::IdentityFunctor<T>(), reduce_dims, stream); | ||
| } | ||
| } | ||
| // dy | ||
| if (dy != nullptr) { | ||
| auto* dy_data = dy->mutable_data<T>(ctx.GetPlace()); | ||
| if (dy->dims() == dout->dims()) { | ||
| if (dy_data != dout_data) { | ||
| // dy = - dout * out / y | ||
| auto size = dy->numel(); | ||
| dim3 grid_size = dim3( | ||
| (size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); | ||
| SimpleElemwiseDivGradCUDAKernel<T><<< | ||
| grid_size, block_size, 0, | ||
| ctx.template device_context<plat::CUDADeviceContext>().stream()>>>( | ||
| x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size, | ||
| nullptr, dy->mutable_data<T>(ctx.GetPlace())); | ||
| } | ||
| } else { | ||
| framework::Tensor tmp_dy; | ||
| tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace()); | ||
|
|
||
| std::vector<const framework::Tensor*> ins = {dout, out, y}; | ||
| std::vector<framework::Tensor*> outs = {&tmp_dy}; | ||
|
|
||
| const auto& dev_ctx = | ||
| ctx.template device_context<platform::CUDADeviceContext>(); | ||
| LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>( | ||
| dev_ctx, ins, &outs, axis, DivGradYFunctor<T>()); | ||
|
|
||
| std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis); | ||
| gpuStream_t stream = ctx.cuda_device_context().stream(); | ||
| TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>( | ||
| tmp_dy, dy, kps::InverseFunctor<T>(), reduce_dims, stream); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,6 +108,21 @@ struct DivDoubleDY { | |
| } | ||
| }; | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
| default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
| const framework::Tensor* x, | ||
| const framework::Tensor* y, | ||
| const framework::Tensor* out, | ||
| const framework::Tensor* dout, | ||
| framework::Tensor* dx, framework::Tensor* dy) { | ||
| int axis = ctx.Attr<int>("axis"); | ||
|
|
||
| ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
| ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); | ||
| } | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
|
|
@@ -116,13 +131,21 @@ elementwise_div_grad(const framework::ExecutionContext& ctx, | |
| const framework::Tensor* out, | ||
| const framework::Tensor* dout, framework::Tensor* dx, | ||
| framework::Tensor* dy) { | ||
| int axis = ctx.Attr<int>("axis"); | ||
| ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
| ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); | ||
| default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
| } | ||
|
|
||
| #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
| template <typename DeviceContext, typename T> | ||
| // cuda definition | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
| default_elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
| const framework::Tensor* x, | ||
| const framework::Tensor* y, | ||
| const framework::Tensor* out, | ||
| const framework::Tensor* dout, | ||
| framework::Tensor* dx, framework::Tensor* dy); | ||
|
|
||
| template <typename DeviceContext, typename T> | ||
| typename std::enable_if< | ||
| std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
|
|
@@ -146,14 +169,12 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> { | |
| auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
| auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
| auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
| int axis = ctx.Attr<int>("axis"); | ||
|
|
||
| if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { | ||
|
||
| elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
| } else { | ||
| ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
| ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), | ||
| DivGradDY<T>()); | ||
| default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, | ||
|
||
| dy); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件已经删除