Skip to content

Commit 4233d37

Browse files
committed
Fix
1 parent 8ce15bf commit 4233d37

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@
9494
func : c_embedding_grad
9595
no_need_buffer : weight
9696

97+
- backward_op : c_softmax_with_cross_entropy_grad
98+
forward: c_softmax_with_cross_entropy (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0) -> Tensor(softmax), Tensor(loss)
99+
args: (Tensor softmax, Tensor label, Tensor loss_grad,int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
100+
output: Tensor(logits_grad)
101+
infer_meta :
102+
func: CSoftmaxWithCrossEntropyGradInferMeta
103+
kernel:
104+
func: c_softmax_with_cross_entropy_grad
105+
data_type: loss_grad
106+
97107
- backward_op : divide_double_grad
98108
forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
99109
args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
@@ -277,16 +287,6 @@
277287
func: set_value_with_scalar_grad
278288
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
279289

280-
- backward_op : c_softmax_with_cross_entropy_grad
281-
forward: c_softmax_with_cross_entropy (Tensor logits, Tensor label, int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0) -> Tensor(softmax), Tensor(loss)
282-
args: (Tensor softmax, Tensor label, Tensor loss_grad,int64_t ignore_index=-100, int ring_id=0, int rank=0, int nranks=0)
283-
output: Tensor(logits_grad)
284-
infer_meta :
285-
func: CSoftmaxWithCrossEntropyGradInferMeta
286-
kernel:
287-
func: c_softmax_with_cross_entropy_grad
288-
data_type: loss_grad
289-
290290
- backward_op : softmax_grad
291291
forward : softmax (Tensor x, int axis) -> Tensor(out)
292292
args : (Tensor out, Tensor out_grad, int axis)

0 commit comments

Comments
 (0)