
Commit 081e430

Optimize perf of softmax_with_cross_entropy_bwd (#40643)
* Optimize perf of softmax_with_cross_entropy_bwd
* fix
* fix
1 parent 1904572 commit 081e430
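
For context, the backward pass of softmax-with-cross-entropy with hard labels computes, for each logit j of a sample with label y, grad_j = (softmax_j - 1) * loss_grad when j == y and grad_j = softmax_j * loss_grad otherwise (and zero for samples equal to ignore_index). Before this change the grad kernel read those softmax values out of logits_grad, which required first copying the forward softmax tensor into the gradient buffer; after this change the kernel takes softmax as a separate input, so that copy can be skipped on the hard-label path. Below is a minimal, self-contained sketch of that gradient rule, assuming a flat [n, dim] layout (the real kernel also handles trailing dimensions via d and remain); the kernel name and types here are illustrative only, not Paddle API.

// Illustrative CUDA sketch of the hard-label softmax cross-entropy gradient rule:
// grad = (softmax - one_hot(label)), scaled by the upstream loss gradient.
#include <cstdint>

__global__ void softmax_xent_grad_hard_label(float* logits_grad,
                                             const float* loss_grad,
                                             const float* softmax,
                                             const int64_t* labels,
                                             int64_t n, int64_t dim,
                                             int64_t ignore_index) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (idx >= n * dim) return;
  int64_t row = idx / dim;  // sample index
  int64_t col = idx % dim;  // class index
  int64_t lbl = labels[row];
  if (lbl == ignore_index) {
    logits_grad[idx] = 0.0f;                                    // masked sample
  } else if (col == lbl) {
    logits_grad[idx] = (softmax[idx] - 1.0f) * loss_grad[row];  // p_y - 1
  } else {
    logits_grad[idx] = softmax[idx] * loss_grad[row];           // p_j
  }
}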

File tree: 1 file changed, +19 −12 lines changed
paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 19 additions & 12 deletions
@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
  */
 template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
-    const int64_t dim, const int64_t d, const int ignore_index) {
+    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
+    const int64_t n, const int64_t dim, const int64_t d,
+    const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
   int64_t idx_dim = (idx / d) % dim;
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
     if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
     } else if (lbl == idx_dim) {
-      logits_grad[idx] =
-          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
+      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
-      logits_grad[idx] *= loss_grad[ids];
+      logits_grad[idx] = softmax[idx] * loss_grad[ids];
     }
   }
 }
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    if (logit_grad != softmax) {
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
+    bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
+    if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data = logit_grad->template data<T>();
 
     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");
 
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
+      const T* softmax_data = softmax->template data<T>();
       const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
-          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
-          ignore_index);
+          logit_grad_data, loss_grad_data, softmax_data, label_data, n,
+          d / remain, remain, ignore_index);
     }
   }
 };
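
On the host side, the reorganized Compute method now performs the TensorCopy of softmax into logit_grad only when a grad kernel will update that buffer in place (soft labels, or use_softmax == false, i.e. the op consumed a precomputed softmax); on the common hard-label path the output is simply allocated with mutable_data and SoftmaxWithCrossEntropyGradHardLabel reads the separate softmax input, saving one full device-to-device copy of an [n, dim]-sized tensor per backward call. A condensed, self-contained sketch of that decision follows; it mirrors the copy_flag condition in the diff above, and needs_softmax_copy is a hypothetical helper, not Paddle API.

// Sketch of when the softmax -> logit_grad copy is still required.
#include <cstdio>

static bool needs_softmax_copy(bool shares_buffer, bool use_softmax,
                               bool soft_label) {
  // shares_buffer: logit_grad and softmax already alias the same tensor.
  // Copy only if the grad kernel will overwrite the softmax values in place.
  return !shares_buffer && (!use_softmax || soft_label);
}

int main() {
  // Hard-label path with the op computing softmax itself: copy is skipped.
  std::printf("hard label, use_softmax=true -> copy: %d\n",
              needs_softmax_copy(false, true, false));  // prints 0
  // Soft-label path still pre-copies softmax into logit_grad.
  std::printf("soft label, use_softmax=true -> copy: %d\n",
              needs_softmax_copy(false, true, true));   // prints 1
  return 0;
}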
