
Commit 2887f0b

fix

1 parent cdf5bd6

File tree

1 file changed (+9 −5 lines changed)

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 9 additions & 5 deletions
@@ -1395,13 +1395,20 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
+    auto stream = context.cuda_device_context().stream();
+    auto ignore_index = context.Attr<int>("ignore_index");
+    auto use_softmax = context.Attr<bool>("use_softmax");
+
+    T* logit_grad_data = nullptr;
     bool copy_flag = (logit_grad != softmax && (!use_softmax || soft_label));
     if (copy_flag) {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
+      logit_grad_data = logit_grad->template data<T>();
+    } else {
+      logit_grad_data =
+          logit_grad->template mutable_data<T>(context.GetPlace());
     }
-    T* logit_grad_data =
-        logit_grad->template mutable_data<T>(context.GetPlace());

     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -1416,9 +1423,6 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 #else
     int block = 512;
 #endif
-    auto stream = context.cuda_device_context().stream();
-    auto ignore_index = context.Attr<int>("ignore_index");
-    auto use_softmax = context.Attr<bool>("use_softmax");

     // do not with softmax op, and input is softmax
     if (!use_softmax) {
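
For context, here is a minimal standalone sketch of the pointer-selection pattern this commit switches to: take the gradient buffer via data() when TensorCopy has just populated the tensor, and allocate via mutable_data() only when no copy happened. FakeTensor, FakeTensorCopy, and the surrounding scaffolding are hypothetical stand-ins, not Paddle's real framework API.

    // Standalone sketch (assumed mock types, not Paddle's Tensor/TensorCopy).
    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for a framework tensor.
    struct FakeTensor {
      std::vector<float> buf;
      // Returns the existing buffer; assumes it has already been allocated.
      float* data() {
        assert(!buf.empty());
        return buf.data();
      }
      // Allocates (if needed) and returns a writable buffer.
      float* mutable_data(std::size_t n) {
        if (buf.size() != n) buf.resize(n);
        return buf.data();
      }
    };

    // Mimics a tensor copy: dst gets its own allocation holding src's values.
    void FakeTensorCopy(const FakeTensor& src, FakeTensor* dst) { dst->buf = src.buf; }

    int main() {
      FakeTensor softmax;
      softmax.buf = {0.1f, 0.2f, 0.7f};

      FakeTensor logit_grad;
      bool use_softmax = false;  // e.g. Attr<bool>("use_softmax") == false
      bool soft_label = false;
      bool copy_flag = (&logit_grad != &softmax && (!use_softmax || soft_label));

      float* logit_grad_data = nullptr;
      if (copy_flag) {
        FakeTensorCopy(softmax, &logit_grad);
        // The copy already allocated the buffer, so reuse it as-is.
        logit_grad_data = logit_grad.data();
      } else {
        // No copy happened; allocate the output buffer explicitly.
        logit_grad_data = logit_grad.mutable_data(softmax.buf.size());
      }

      std::printf("first grad value: %f\n", logit_grad_data[0]);
      return 0;
    }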
