@@ -127,6 +127,9 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
127127 if (!(ends[axis_index] == -1 &&
128128 strides[axis_index] < 0 )) { // skip None stop condition
129129 ends[axis_index] = ends[axis_index] + axis_size;
130+ if (ends[axis_index] < 0 ) {
131+ ends[axis_index] = 0 ;
132+ }
130133 }
131134 }
132135 if (decrease_axis_affect) {
@@ -147,9 +150,8 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
147150 strides[axis_index] = -strides[axis_index];
148151 if (starts[axis_index] > ends[axis_index]) {
149152 // swap the reverse
150- auto end_dim = dims[axis_index] - 1 < starts[axis_index]
151- ? dims[axis_index] - 1
152- : starts[axis_index];
153+ auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
154+ : starts[axis_index];
153155 auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
154156 offset = offset == 0 ? strides[axis_index] : offset;
155157
@@ -378,33 +380,32 @@ class StridedSliceKernel : public framework::OpKernel<T> {
378380 TensorCopy (in_tensor, context.GetPlace (), out_tensor);
379381 }
380382
381- return ;
382- }
383- auto in = context.Input <framework::Tensor>(" Input" );
384- auto out = context.Output <framework::Tensor>(" Out" );
385- out->Resize (out_dims);
386- out->mutable_data <T>(context.GetPlace ());
387- auto in_t =
388- framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From (
389- *in);
390- auto out_t =
391- framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From (
392- *out, out_dims);
393- if (need_reverse) {
394- framework::Tensor tmp;
395- tmp.mutable_data <T>(out_dims, context.GetPlace ());
396- auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
397- Eigen::DenseIndex>::From (tmp);
398- tmp_t .device (place) =
399- in_t .stridedSlice (starts_indices, ends_indices, strides_indices);
400- out_t .device (place) = tmp_t .reverse (reverse_axis);
401383 } else {
402- out_t .device (place) =
403- in_t .stridedSlice (starts_indices, ends_indices, strides_indices);
404- }
384+ auto in = context.Input <framework::Tensor>(" Input" );
385+ auto out = context.Output <framework::Tensor>(" Out" );
386+ out->Resize (out_dims);
387+ out->mutable_data <T>(context.GetPlace ());
388+ auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
389+ Eigen::DenseIndex>::From (*in);
390+ auto out_t =
391+ framework::EigenTensor<T, D, Eigen::RowMajor,
392+ Eigen::DenseIndex>::From (*out, out_dims);
393+ if (need_reverse) {
394+ framework::Tensor tmp;
395+ tmp.mutable_data <T>(out_dims, context.GetPlace ());
396+ auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
397+ Eigen::DenseIndex>::From (tmp);
398+ tmp_t .device (place) =
399+ in_t .stridedSlice (starts_indices, ends_indices, strides_indices);
400+ out_t .device (place) = tmp_t .reverse (reverse_axis);
401+ } else {
402+ out_t .device (place) =
403+ in_t .stridedSlice (starts_indices, ends_indices, strides_indices);
404+ }
405405
406- if (decrease_axis.size () > 0 ) {
407- out->Resize (out_dims_origin);
406+ if (decrease_axis.size () > 0 ) {
407+ out->Resize (out_dims_origin);
408+ }
408409 }
409410 }
410411};
@@ -453,11 +454,11 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
453454 auto * out_var = context.OutputVar (framework::GradVarName (" Input" ));
454455 bool is_out_var_array = out_var->IsType <LoDTensorArray>();
455456 if (is_out_var_array) {
456- // Since the shape of `framework::GradVarName("Input")` of
457- // StridedSliceGrad
458- // cannot be calculated by `framework::GradVarName("Output")`,
459- // the dim of "Input" is used to calculate the output shape.
460- // when set it to inplace OP, there may be some problems.
457+ // Note(weixin): Since the shape of `framework::GradVarName("Input")` of
458+ // StridedSliceGrad cannot be calculated by
459+ // `framework::GradVarName("Output")`, the dim of "Input" is used to
460+ // calculate the output shape. when set it to inplace OP, there may be
461+ // some problems.
461462 const int64_t size =
462463 context.Input <framework::LoDTensorArray>(" Input" )->size ();
463464
@@ -621,40 +622,39 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
621622 set_zero (dev_ctx, d_out_tensor, static_cast <T>(0 ));
622623 }
623624 }
624- return ;
625- }
626625
627- auto * d_input =
628- context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
629- auto * d_out =
630- context.Output <framework::Tensor>(framework::GradVarName (" Input" ));
626+ } else {
627+ auto * d_input =
628+ context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
629+ auto * d_out =
630+ context.Output <framework::Tensor>(framework::GradVarName (" Input" ));
631631
632- d_out->mutable_data <T>(context.GetPlace ());
632+ d_out->mutable_data <T>(context.GetPlace ());
633633
634- math::SetConstant<DeviceContext, T> set_zero;
635- set_zero (dev_ctx, d_out, static_cast <T>(0 ));
634+ math::SetConstant<DeviceContext, T> set_zero;
635+ set_zero (dev_ctx, d_out, static_cast <T>(0 ));
636636
637- auto in_dims = d_input->dims ();
637+ auto in_dims = d_input->dims ();
638638
639- auto in_t =
640- framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From (
641- *d_input);
642- auto out_t =
643- framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From (
644- *d_out, out_dims);
645- if (need_reverse) {
646- framework::Tensor reverse_input;
647- reverse_input.mutable_data <T>(in_dims, context.GetPlace ());
648- auto reverse_in_t =
639+ auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
640+ Eigen::DenseIndex>::From (*d_input);
641+ auto out_t =
649642 framework::EigenTensor<T, D, Eigen::RowMajor,
650- Eigen::DenseIndex>::From (reverse_input);
651-
652- reverse_in_t .device (place) = in_t .reverse (reverse_axis);
653- out_t .stridedSlice (starts_indices, ends_indices, strides_indices)
654- .device (place) = reverse_in_t ;
655- } else {
656- out_t .stridedSlice (starts_indices, ends_indices, strides_indices)
657- .device (place) = in_t ;
643+ Eigen::DenseIndex>::From (*d_out, out_dims);
644+ if (need_reverse) {
645+ framework::Tensor reverse_input;
646+ reverse_input.mutable_data <T>(in_dims, context.GetPlace ());
647+ auto reverse_in_t =
648+ framework::EigenTensor<T, D, Eigen::RowMajor,
649+ Eigen::DenseIndex>::From (reverse_input);
650+
651+ reverse_in_t .device (place) = in_t .reverse (reverse_axis);
652+ out_t .stridedSlice (starts_indices, ends_indices, strides_indices)
653+ .device (place) = reverse_in_t ;
654+ } else {
655+ out_t .stridedSlice (starts_indices, ends_indices, strides_indices)
656+ .device (place) = in_t ;
657+ }
658658 }
659659 }
660660};
0 commit comments