Skip to content

Commit 8a5f39f

Browse files
committed
fix set value grad
1 parent 29eaaa1 commit 8a5f39f

File tree

7 files changed

+258
-23
lines changed

7 files changed

+258
-23
lines changed

paddle/fluid/operators/set_value_op.cc

Lines changed: 15 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -151,31 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
151151

152152
protected:
153153
void Apply(GradOpPtr<T> op) const override {
154-
if (this->HasInput("ValueTensor")) {
155-
op->SetType("set_value_grad");
154+
op->SetType("set_value_grad");
155+
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
156+
if (this->HasInput("StartsTensorList")) {
157+
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
158+
}
159+
if (this->HasInput("EndsTensorList")) {
160+
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
161+
}
162+
if (this->HasInput("StepsTensorList")) {
163+
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
164+
}
156165

157-
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
158-
op->SetInput("ValueTensor", this->Input("ValueTensor"));
159-
if (this->HasInput("StartsTensorList")) {
160-
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
161-
}
162-
if (this->HasInput("EndsTensorList")) {
163-
op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
164-
}
165-
if (this->HasInput("StepsTensorList")) {
166-
op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
167-
}
166+
op->SetAttrMap(this->Attrs());
168167

169-
op->SetAttrMap(this->Attrs());
168+
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
170169

170+
if (this->HasInput("ValueTensor")) {
171+
op->SetInput("ValueTensor", this->Input("ValueTensor"));
171172
op->SetOutput(framework::GradVarName("ValueTensor"),
172173
this->InputGrad("ValueTensor"));
173-
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
174-
175-
} else {
176-
op->SetType("assign");
177-
op->SetInput("X", this->OutputGrad("Out"));
178-
op->SetOutput("Out", this->InputGrad("Input"));
179174
}
180175
}
181176
};

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -612,14 +612,14 @@
612612

613613
- backward_op : set_value_grad
614614
forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
615-
args : (Tensor out_grad)
615+
args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
616616
output : Tensor(x_grad)
617617
infer_meta:
618618
func: UnchangedInferMeta
619619
param: [out_grad]
620620
kernel:
621-
func: assign
622-
param: [out_grad]
621+
func: set_value_with_scalar_grad
622+
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
623623

624624
- backward_op : set_value_with_tensor_grad
625625
forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out)

paddle/phi/kernels/cpu/set_value_grad_kernel.cc

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
3535
phi::dtype::float16,
3636
phi::dtype::complex<float>,
3737
phi::dtype::complex<double>) {}
38+
39+
PD_REGISTER_KERNEL(set_value_with_scalar_grad,
40+
CPU,
41+
ALL_LAYOUT,
42+
phi::SetValueWithScalarGradKernel,
43+
float,
44+
double,
45+
int,
46+
int64_t,
47+
bool,
48+
int16_t,
49+
uint8_t,
50+
int8_t,
51+
phi::dtype::bfloat16,
52+
phi::dtype::float16,
53+
phi::dtype::complex<float>,
54+
phi::dtype::complex<double>) {}

paddle/phi/kernels/gpu/set_value_grad_kernel.cu

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
3535
phi::dtype::bfloat16,
3636
phi::dtype::complex<float>,
3737
phi::dtype::complex<double>) {}
38+
39+
PD_REGISTER_KERNEL(set_value_with_scalar_grad,
40+
GPU,
41+
ALL_LAYOUT,
42+
phi::SetValueWithScalarGradKernel,
43+
float,
44+
double,
45+
int,
46+
int64_t,
47+
bool,
48+
int16_t,
49+
uint8_t,
50+
int8_t,
51+
phi::dtype::float16,
52+
phi::dtype::bfloat16,
53+
phi::dtype::complex<float>,
54+
phi::dtype::complex<double>) {}

paddle/phi/kernels/impl/set_value_grad_kernel_impl.h

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,4 +341,97 @@ void SetValueGradKernel(const Context& dev_ctx,
341341
}
342342
}
343343

344+
template <typename T, typename Context>
345+
void SetValueWithScalarGradKernel(const Context& dev_ctx,
346+
const DenseTensor& out_grad,
347+
const IntArray& starts,
348+
const IntArray& ends,
349+
const IntArray& steps,
350+
const std::vector<int64_t>& axes,
351+
const std::vector<int64_t>& decrease_axes,
352+
const std::vector<int64_t>& none_axes,
353+
DenseTensor* x_grad) {
354+
const int rank = out_grad.dims().size();
355+
356+
switch (rank) {
357+
case 1:
358+
SetValueGradImpl<T, Context, 1>(dev_ctx,
359+
out_grad,
360+
starts,
361+
ends,
362+
steps,
363+
axes,
364+
decrease_axes,
365+
none_axes,
366+
x_grad,
367+
nullptr);
368+
break;
369+
case 2:
370+
SetValueGradImpl<T, Context, 2>(dev_ctx,
371+
out_grad,
372+
starts,
373+
ends,
374+
steps,
375+
axes,
376+
decrease_axes,
377+
none_axes,
378+
x_grad,
379+
nullptr);
380+
break;
381+
case 3:
382+
SetValueGradImpl<T, Context, 3>(dev_ctx,
383+
out_grad,
384+
starts,
385+
ends,
386+
steps,
387+
axes,
388+
decrease_axes,
389+
none_axes,
390+
x_grad,
391+
nullptr);
392+
break;
393+
case 4:
394+
SetValueGradImpl<T, Context, 4>(dev_ctx,
395+
out_grad,
396+
starts,
397+
ends,
398+
steps,
399+
axes,
400+
decrease_axes,
401+
none_axes,
402+
x_grad,
403+
nullptr);
404+
break;
405+
case 5:
406+
SetValueGradImpl<T, Context, 5>(dev_ctx,
407+
out_grad,
408+
starts,
409+
ends,
410+
steps,
411+
axes,
412+
decrease_axes,
413+
none_axes,
414+
x_grad,
415+
nullptr);
416+
break;
417+
case 6:
418+
SetValueGradImpl<T, Context, 6>(dev_ctx,
419+
out_grad,
420+
starts,
421+
ends,
422+
steps,
423+
axes,
424+
decrease_axes,
425+
none_axes,
426+
x_grad,
427+
nullptr);
428+
break;
429+
default:
430+
PADDLE_THROW(phi::errors::InvalidArgument(
431+
"The rank of set_value_with_scalar_grad's input should be less than "
432+
"7, but "
433+
"received %d.",
434+
rank));
435+
}
436+
}
344437
} // namespace phi

paddle/phi/kernels/set_value_grad_kernel.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,4 +32,14 @@ void SetValueGradKernel(const Context& dev_ctx,
3232
DenseTensor* x_grad,
3333
DenseTensor* value_grad);
3434

35+
template <typename T, typename Context>
36+
void SetValueWithScalarGradKernel(const Context& dev_ctx,
37+
const DenseTensor& out_grad,
38+
const IntArray& starts,
39+
const IntArray& ends,
40+
const IntArray& steps,
41+
const std::vector<int64_t>& axes,
42+
const std::vector<int64_t>& decrease_axes,
43+
const std::vector<int64_t>& none_axes,
44+
DenseTensor* x_grad);
3545
} // namespace phi

paddle/phi/kernels/xpu/set_value_grad_kernel.cc

Lines changed: 103 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -397,6 +397,100 @@ void SetValueGradKernel(const Context& dev_ctx,
397397
}
398398
}
399399

400+
template <typename T, typename Context>
401+
void SetValueWithScalarGradKernel(const Context& dev_ctx,
402+
const DenseTensor& out_grad,
403+
const IntArray& starts,
404+
const IntArray& ends,
405+
const IntArray& steps,
406+
const std::vector<int64_t>& axes,
407+
const std::vector<int64_t>& decrease_axes,
408+
const std::vector<int64_t>& none_axes,
409+
DenseTensor* x_grad) {
410+
const int rank = out_grad.dims().size();
411+
412+
switch (rank) {
413+
case 1:
414+
SetValueGradImpl<T, Context, 1>(dev_ctx,
415+
out_grad,
416+
starts,
417+
ends,
418+
steps,
419+
axes,
420+
decrease_axes,
421+
none_axes,
422+
x_grad,
423+
nullptr);
424+
break;
425+
case 2:
426+
SetValueGradImpl<T, Context, 2>(dev_ctx,
427+
out_grad,
428+
starts,
429+
ends,
430+
steps,
431+
axes,
432+
decrease_axes,
433+
none_axes,
434+
x_grad,
435+
nullptr);
436+
break;
437+
case 3:
438+
SetValueGradImpl<T, Context, 3>(dev_ctx,
439+
out_grad,
440+
starts,
441+
ends,
442+
steps,
443+
axes,
444+
decrease_axes,
445+
none_axes,
446+
x_grad,
447+
nullptr);
448+
break;
449+
case 4:
450+
SetValueGradImpl<T, Context, 4>(dev_ctx,
451+
out_grad,
452+
starts,
453+
ends,
454+
steps,
455+
axes,
456+
decrease_axes,
457+
none_axes,
458+
x_grad,
459+
nullptr);
460+
break;
461+
case 5:
462+
SetValueGradImpl<T, Context, 5>(dev_ctx,
463+
out_grad,
464+
starts,
465+
ends,
466+
steps,
467+
axes,
468+
decrease_axes,
469+
none_axes,
470+
x_grad,
471+
nullptr);
472+
break;
473+
case 6:
474+
SetValueGradImpl<T, Context, 6>(dev_ctx,
475+
out_grad,
476+
starts,
477+
ends,
478+
steps,
479+
axes,
480+
decrease_axes,
481+
none_axes,
482+
x_grad,
483+
nullptr);
484+
break;
485+
default:
486+
PADDLE_THROW(phi::errors::InvalidArgument(
487+
"The rank of set_value_with_scalar_grad's input should be less than "
488+
"7, but "
489+
"received %d.",
490+
rank));
491+
}
492+
}
493+
400494
} // namespace phi
401495

402496
PD_REGISTER_KERNEL(set_value_grad,
@@ -407,3 +501,12 @@ PD_REGISTER_KERNEL(set_value_grad,
407501
phi::dtype::float16,
408502
int,
409503
int64_t) {}
504+
505+
PD_REGISTER_KERNEL(set_value_with_scalar_grad,
506+
XPU,
507+
ALL_LAYOUT,
508+
phi::SetValueWithScalarGradKernel,
509+
float,
510+
phi::dtype::float16,
511+
int,
512+
int64_t) {}

0 commit comments

Comments
 (0)