@@ -760,8 +760,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
 */
 template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
-    const int64_t dim, const int64_t d, const int ignore_index) {
+    T* logits_grad, const T* loss_grad, const T* softmax, const LabelT* labels,
+    const int64_t n, const int64_t dim, const int64_t d,
+    const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
   int64_t idx_dim = (idx / d) % dim;
@@ -773,10 +774,9 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
     if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
     } else if (lbl == idx_dim) {
-      logits_grad[idx] =
-          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
+      logits_grad[idx] = (softmax[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
-      logits_grad[idx] *= loss_grad[ids];
+      logits_grad[idx] = softmax[idx] * loss_grad[ids];
     }
   }
 }
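
Note: the rewritten kernel reads softmax values from the separate Softmax input instead of reusing logits_grad in place, which is what later lets the caller skip pre-filling logit_grad on this path. Below is a minimal CPU sketch of the same hard-label update rule, under the assumption that the class axis is the last one (so the kernel's inner stride d is 1); HardLabelGradReference is an illustrative name, not part of the patch.

// CPU reference sketch (illustrative only): for each sample i with hard label
// lbl, dL/dlogit_j = (softmax_j - 1[j == lbl]) * loss_grad_i, and the whole
// row is zeroed when lbl == ignore_index.
#include <cstdint>
#include <vector>

template <typename T, typename LabelT>
void HardLabelGradReference(std::vector<T>* logits_grad,        // n * dim
                            const std::vector<T>& loss_grad,    // n
                            const std::vector<T>& softmax,      // n * dim
                            const std::vector<LabelT>& labels,  // n
                            int64_t n, int64_t dim, int ignore_index) {
  for (int64_t i = 0; i < n; ++i) {
    auto lbl = static_cast<int64_t>(labels[i]);
    for (int64_t j = 0; j < dim; ++j) {
      int64_t idx = i * dim + j;
      if (lbl == ignore_index) {
        (*logits_grad)[idx] = static_cast<T>(0.0);
      } else if (lbl == j) {
        (*logits_grad)[idx] =
            (softmax[idx] - static_cast<T>(1.0)) * loss_grad[i];
      } else {
        (*logits_grad)[idx] = softmax[idx] * loss_grad[i];
      }
    }
  }
}
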
@@ -1395,11 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
-    if (logit_grad != softmax) {
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
+    bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
+    if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data = logit_grad->template data<T>();

     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
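
Note: with the hard-label kernel no longer reading logits_grad in place, the softmax-to-logit_grad copy is only needed on the paths that still do (soft labels, or use_softmax == false); otherwise allocating the output with mutable_data is enough. A compact restatement of that predicate, with an illustrative helper name that is not part of the patch:

// Pre-fill logit_grad with softmax only for the in-place backward paths.
// same_buffer: logit_grad and softmax refer to the same tensor (inplace op).
bool NeedsSoftmaxCopy(bool same_buffer, bool use_softmax, bool soft_label) {
  return !same_buffer && (!use_softmax || soft_label);
}
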
@@ -1414,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
     if (!use_softmax) {
@@ -1451,11 +1457,12 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
+      const T* softmax_data = softmax->template data<T>();
       const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
-          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
-          ignore_index);
+          logit_grad_data, loss_grad_data, softmax_data, label_data, n,
+          d / remain, remain, ignore_index);
     }
   }
 };