File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -1395,13 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
13951395 Tensor* logit_grad =
13961396 context.Output <Tensor>(framework::GradVarName (" Logits" ));
13971397 const Tensor* softmax = context.Input <Tensor>(" Softmax" );
1398+ auto stream = context.cuda_device_context ().stream ();
1399+ auto ignore_index = context.Attr <int >(" ignore_index" );
1400+ auto use_softmax = context.Attr <bool >(" use_softmax" );
1401+
1402+ T* logit_grad_data = nullptr ;
13981403 bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
13991404 if (copy_flag) {
14001405 framework::TensorCopy (*softmax, context.GetPlace (),
14011406 context.device_context (), logit_grad);
1407+ logit_grad_data = logit_grad->template data <T>();
1408+ } else {
1409+ logit_grad_data =
1410+ logit_grad->template mutable_data <T>(context.GetPlace ());
14021411 }
1403- T* logit_grad_data =
1404- logit_grad->template mutable_data <T>(context.GetPlace ());
14051412
14061413 const int rank = logit_grad->dims ().size ();
14071414 const int axis = phi::funcs::CanonicalAxis (context.Attr <int >(" axis" ), rank);
@@ -1416,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
14161423#else
14171424 int block = 512 ;
14181425#endif
1419- auto stream = context.cuda_device_context ().stream ();
1420- auto ignore_index = context.Attr <int >(" ignore_index" );
1421- auto use_softmax = context.Attr <bool >(" use_softmax" );
14221426
14231427 // do not with softmax op, and input is softmax
14241428 if (!use_softmax) {
You can’t perform that action at this time.
0 commit comments