Skip to content

Commit 3b6ad37

Browse files
committed
polish code
1 parent 39aa575 commit 3b6ad37

File tree

4 files changed

+148
-84
lines changed

4 files changed

+148
-84
lines changed

paddle/fluid/operators/strided_slice_op.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,19 @@ class StridedSliceOp : public framework::OperatorWithKernel {
163163
auto *in_var = ctx.InputVar("Input");
164164
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
165165
if (is_in_var_array) {
166+
auto &tensor_array = in_var->Get<framework::LoDTensorArray>();
167+
for (auto &tensor : tensor_array) {
168+
if (!platform::is_cuda_pinned_place(tensor.place())) {
169+
PADDLE_ENFORCE_EQ(
170+
platform::is_same_place(tensor.place(),
171+
ctx.device_context().GetPlace()),
172+
true, platform::errors::InvalidArgument(
173+
"Place of context is %s. Place of context is %s. They "
174+
"are should be same, but reveived different place.",
175+
string::to_string(ctx.device_context().GetPlace()),
176+
string::to_string(tensor.place())));
177+
}
178+
}
166179
return framework::OpKernelType(
167180
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
168181
ctx.device_context());

paddle/fluid/operators/strided_slice_op.h

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
127127
if (!(ends[axis_index] == -1 &&
128128
strides[axis_index] < 0)) { // skip None stop condition
129129
ends[axis_index] = ends[axis_index] + axis_size;
130+
if (ends[axis_index] < 0) {
131+
ends[axis_index] = 0;
132+
}
130133
}
131134
}
132135
if (decrease_axis_affect) {
@@ -147,9 +150,8 @@ static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
147150
strides[axis_index] = -strides[axis_index];
148151
if (starts[axis_index] > ends[axis_index]) {
149152
// swap the reverse
150-
auto end_dim = dims[axis_index] - 1 < starts[axis_index]
151-
? dims[axis_index] - 1
152-
: starts[axis_index];
153+
auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
154+
: starts[axis_index];
153155
auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
154156
offset = offset == 0 ? strides[axis_index] : offset;
155157

@@ -378,33 +380,32 @@ class StridedSliceKernel : public framework::OpKernel<T> {
378380
TensorCopy(in_tensor, context.GetPlace(), out_tensor);
379381
}
380382

381-
return;
382-
}
383-
auto in = context.Input<framework::Tensor>("Input");
384-
auto out = context.Output<framework::Tensor>("Out");
385-
out->Resize(out_dims);
386-
out->mutable_data<T>(context.GetPlace());
387-
auto in_t =
388-
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
389-
*in);
390-
auto out_t =
391-
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
392-
*out, out_dims);
393-
if (need_reverse) {
394-
framework::Tensor tmp;
395-
tmp.mutable_data<T>(out_dims, context.GetPlace());
396-
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
397-
Eigen::DenseIndex>::From(tmp);
398-
tmp_t.device(place) =
399-
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
400-
out_t.device(place) = tmp_t.reverse(reverse_axis);
401383
} else {
402-
out_t.device(place) =
403-
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
404-
}
384+
auto in = context.Input<framework::Tensor>("Input");
385+
auto out = context.Output<framework::Tensor>("Out");
386+
out->Resize(out_dims);
387+
out->mutable_data<T>(context.GetPlace());
388+
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
389+
Eigen::DenseIndex>::From(*in);
390+
auto out_t =
391+
framework::EigenTensor<T, D, Eigen::RowMajor,
392+
Eigen::DenseIndex>::From(*out, out_dims);
393+
if (need_reverse) {
394+
framework::Tensor tmp;
395+
tmp.mutable_data<T>(out_dims, context.GetPlace());
396+
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
397+
Eigen::DenseIndex>::From(tmp);
398+
tmp_t.device(place) =
399+
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
400+
out_t.device(place) = tmp_t.reverse(reverse_axis);
401+
} else {
402+
out_t.device(place) =
403+
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
404+
}
405405

406-
if (decrease_axis.size() > 0) {
407-
out->Resize(out_dims_origin);
406+
if (decrease_axis.size() > 0) {
407+
out->Resize(out_dims_origin);
408+
}
408409
}
409410
}
410411
};
@@ -453,11 +454,11 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
453454
auto* out_var = context.OutputVar(framework::GradVarName("Input"));
454455
bool is_out_var_array = out_var->IsType<LoDTensorArray>();
455456
if (is_out_var_array) {
456-
// Since the shape of `framework::GradVarName("Input")` of
457-
// StridedSliceGrad
458-
// cannot be calculated by `framework::GradVarName("Output")`,
459-
// the dim of "Input" is used to calculate the output shape.
460-
// when set it to inplace OP, there may be some problems.
457+
// Note(weixin): Since the shape of `framework::GradVarName("Input")` of
458+
// StridedSliceGrad cannot be calculated by
459+
// `framework::GradVarName("Output")`, the dim of "Input" is used to
460+
// calculate the output shape. When it is set as an inplace OP, there may be
461+
// some problems.
461462
const int64_t size =
462463
context.Input<framework::LoDTensorArray>("Input")->size();
463464

@@ -621,40 +622,39 @@ class StridedSliceGradKernel : public framework::OpKernel<T> {
621622
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
622623
}
623624
}
624-
return;
625-
}
626625

627-
auto* d_input =
628-
context.Input<framework::Tensor>(framework::GradVarName("Out"));
629-
auto* d_out =
630-
context.Output<framework::Tensor>(framework::GradVarName("Input"));
626+
} else {
627+
auto* d_input =
628+
context.Input<framework::Tensor>(framework::GradVarName("Out"));
629+
auto* d_out =
630+
context.Output<framework::Tensor>(framework::GradVarName("Input"));
631631

632-
d_out->mutable_data<T>(context.GetPlace());
632+
d_out->mutable_data<T>(context.GetPlace());
633633

634-
math::SetConstant<DeviceContext, T> set_zero;
635-
set_zero(dev_ctx, d_out, static_cast<T>(0));
634+
math::SetConstant<DeviceContext, T> set_zero;
635+
set_zero(dev_ctx, d_out, static_cast<T>(0));
636636

637-
auto in_dims = d_input->dims();
637+
auto in_dims = d_input->dims();
638638

639-
auto in_t =
640-
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
641-
*d_input);
642-
auto out_t =
643-
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
644-
*d_out, out_dims);
645-
if (need_reverse) {
646-
framework::Tensor reverse_input;
647-
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
648-
auto reverse_in_t =
639+
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
640+
Eigen::DenseIndex>::From(*d_input);
641+
auto out_t =
649642
framework::EigenTensor<T, D, Eigen::RowMajor,
650-
Eigen::DenseIndex>::From(reverse_input);
651-
652-
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
653-
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
654-
.device(place) = reverse_in_t;
655-
} else {
656-
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
657-
.device(place) = in_t;
643+
Eigen::DenseIndex>::From(*d_out, out_dims);
644+
if (need_reverse) {
645+
framework::Tensor reverse_input;
646+
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
647+
auto reverse_in_t =
648+
framework::EigenTensor<T, D, Eigen::RowMajor,
649+
Eigen::DenseIndex>::From(reverse_input);
650+
651+
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
652+
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
653+
.device(place) = reverse_in_t;
654+
} else {
655+
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
656+
.device(place) = in_t;
657+
}
658658
}
659659
}
660660
};

python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ def test_set_value_with_save(self):
177177
output_spec=None)
178178

179179

180-
class TestSliceSupplementCase(unittest.TestCase):
180+
class TestSliceSupplementSpecialCase(unittest.TestCase):
181+
# Unit tests for slice indices where abs(step) > 0, e.g. x[::2].
181182
def test_static_slice_step(self):
182183
paddle.enable_static()
183184
array = np.arange(4**3).reshape((4, 4, 4)).astype('int64')
@@ -242,6 +243,20 @@ def test_compare_paddle_strided_slice_with_numpy(self):
242243
np.array_equal(sl.numpy(), array[s2[0]:e2[0]:stride2[0], s2[1]:e2[
243244
1]:stride2[1]]))
244245

246+
array = np.arange(6 * 7 * 8).reshape((6, 7, 8))
247+
pt = paddle.to_tensor(array)
248+
s2 = [7, -1]
249+
e2 = [2, -5]
250+
stride2 = [-2, -3]
251+
sl = paddle.strided_slice(
252+
pt, axes=[0, 2], starts=s2, ends=e2, strides=stride2)
253+
254+
array_slice = array[s2[0]:e2[0]:stride2[0], ::, s2[1]:e2[1]:stride2[1]]
255+
self.assertTrue(
256+
np.array_equal(sl.numpy(), array_slice),
257+
msg="paddle.strided_slice:\n {} \n numpy slice:\n{}".format(
258+
sl.numpy(), array_slice))
259+
245260

246261
if __name__ == '__main__':
247262
unittest.main()

python/paddle/fluid/tests/unittests/test_strided_slice_op.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,49 @@ def create_case(self, net):
701701
msg="dygraph graph result:\n{} \nstatic dygraph result:\n{}".format(
702702
l1.numpy(), l2.numpy()))
703703

704+
def test_strided_slice_tensor_array_cuda_pinned_place(self):
705+
if paddle.device.is_compiled_with_cuda():
706+
with paddle.fluid.dygraph.guard():
707+
708+
class Simple(paddle.nn.Layer):
709+
def __init__(self):
710+
super(Simple, self).__init__()
711+
712+
def forward(self, inps):
713+
tensor_array = None
714+
for i, tensor in enumerate(inps):
715+
index = paddle.full(
716+
shape=[1], dtype='int64', fill_value=i)
717+
if tensor_array is None:
718+
tensor_array = paddle.tensor.array_write(
719+
tensor, i=index)
720+
else:
721+
paddle.tensor.array_write(
722+
tensor, i=index, array=tensor_array)
723+
724+
array1 = paddle.concat(tensor_array)
725+
array2 = paddle.concat(tensor_array[::-1])
726+
return array1 + array2 * array2
727+
728+
net = Simple()
729+
func = paddle.jit.to_static(net.forward)
730+
731+
inps1 = paddle.to_tensor(
732+
np.random.randn(2, 10),
733+
place=paddle.CUDAPinnedPlace(),
734+
stop_gradient=False)
735+
inps2 = paddle.to_tensor(
736+
np.random.randn(2, 10),
737+
place=paddle.CUDAPinnedPlace(),
738+
stop_gradient=False)
739+
740+
self.assertTrue(inps1.place.is_cuda_pinned_place())
741+
self.assertTrue(inps2.place.is_cuda_pinned_place())
742+
743+
result = func([inps1, inps2])
744+
745+
self.assertFalse(result.place.is_cuda_pinned_place())
746+
704747
def test_strided_slice_tensor_array(self):
705748
class Net(ArrayLayer):
706749
def array_slice(self, tensors):
@@ -854,28 +897,21 @@ def array_slice(self, tensors):
854897

855898
self.create_case(Net(input_size=112, array_size=13))
856899

857-
# TODO(weixin):Currently, the case that the start index is
858-
# less than `-array_size` is not supported.
859-
# The index parsed from the slice of the VarBase/Variable
860-
# is processed before being passed to `strided_slice_op`.
861-
# The slice may be processed uniformly, instead of
862-
# processing separately for TensorArray\VarBase\Variable.
863-
#
864-
# class Net(ArrayLayer):
865-
#
866-
# def array_slice(self,tensors):
867-
# return tensors[-60:20:3]
868-
# self.create_case(Net(input_size=112,array_size=13))
869-
870-
# class Net(ArrayLayer):
871-
# def array_slice(self, tensors):
872-
# return tensors[-3:-60:-3]
873-
874-
# self.create_case(Net(input_size=112, array_size=13))
875-
876-
# class Net(ArrayLayer):
877-
# def array_slice(self, tensors):
878-
# return tensors[-1:-60:-3]
900+
class Net(ArrayLayer):
901+
def array_slice(self, tensors):
902+
return tensors[-60:20:3]
903+
904+
self.create_case(Net(input_size=112, array_size=13))
905+
906+
class Net(ArrayLayer):
907+
def array_slice(self, tensors):
908+
return tensors[-3:-60:-3]
909+
910+
self.create_case(Net(input_size=112, array_size=13))
911+
912+
class Net(ArrayLayer):
913+
def array_slice(self, tensors):
914+
return tensors[-1:-60:-3]
879915

880916

881917
if __name__ == "__main__":

0 commit comments

Comments
 (0)