@@ -151,31 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
151151
152152 protected:
153153 void Apply (GradOpPtr<T> op) const override {
154- if (this ->HasInput (" ValueTensor" )) {
155- op->SetType (" set_value_grad" );
154+ op->SetType (" set_value_grad" );
155+ op->SetInput (framework::GradVarName (" Out" ), this ->OutputGrad (" Out" ));
156+ if (this ->HasInput (" StartsTensorList" )) {
157+ op->SetInput (" StartsTensorList" , this ->Input (" StartsTensorList" ));
158+ }
159+ if (this ->HasInput (" EndsTensorList" )) {
160+ op->SetInput (" EndsTensorList" , this ->Input (" EndsTensorList" ));
161+ }
162+ if (this ->HasInput (" StepsTensorList" )) {
163+ op->SetInput (" StepsTensorList" , this ->Input (" StepsTensorList" ));
164+ }
156165
157- op->SetInput (framework::GradVarName (" Out" ), this ->OutputGrad (" Out" ));
158- op->SetInput (" ValueTensor" , this ->Input (" ValueTensor" ));
159- if (this ->HasInput (" StartsTensorList" )) {
160- op->SetInput (" StartsTensorList" , this ->Input (" StartsTensorList" ));
161- }
162- if (this ->HasInput (" EndsTensorList" )) {
163- op->SetInput (" EndsTensorList" , this ->Input (" EndsTensorList" ));
164- }
165- if (this ->HasInput (" StepsTensorList" )) {
166- op->SetInput (" StepsTensorList" , this ->Input (" StepsTensorList" ));
167- }
166+ op->SetAttrMap (this ->Attrs ());
168167
169- op->SetAttrMap ( this ->Attrs ( ));
168+ op->SetOutput ( framework::GradVarName ( " Input " ), this ->InputGrad ( " Input " ));
170169
170+ if (this ->HasInput (" ValueTensor" )) {
171+ op->SetInput (" ValueTensor" , this ->Input (" ValueTensor" ));
171172 op->SetOutput (framework::GradVarName (" ValueTensor" ),
172173 this ->InputGrad (" ValueTensor" ));
173- op->SetOutput (framework::GradVarName (" Input" ), this ->InputGrad (" Input" ));
174-
175- } else {
176- op->SetType (" assign" );
177- op->SetInput (" X" , this ->OutputGrad (" Out" ));
178- op->SetOutput (" Out" , this ->InputGrad (" Input" ));
179174 }
180175 }
181176};
0 commit comments