@@ -36,10 +36,9 @@ struct float16;
3636namespace paddle {
3737namespace operators {
3838
39- static void DataCopy (const framework::LoDTensor &src_item,
39+ static void DeepCopy (const framework::LoDTensor &src_item,
4040 const std::string &fetch_var_name,
41- framework::LoDTensor *dst_item,
42- const platform::DeviceContext &dev_ctx) {
41+ framework::LoDTensor *dst_item) {
4342 if (src_item.IsInitialized () && src_item.numel () > 0 ) {
4443#ifdef PADDLE_WITH_MKLDNN
4544 // Conversion from MKL-DNN to Paddle
@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item,
5352 : paddle::platform::MKLDNNDeviceContext::tls ()
5453 .get_cur_paddle_data_layout (),
5554 src_item, &out, platform::CPUPlace ());
56- TensorCopy (src_item , platform::CPUPlace (), dev_ctx , dst_item);
55+ TensorCopySync (out , platform::CPUPlace (), dst_item);
5756 } else {
58- if (platform::is_gpu_place (src_item.place ())) {
59- #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
60- TensorCopy (src_item, platform::CUDAPinnedPlace (), dev_ctx, dst_item);
61- #endif
62- } else {
63- TensorCopy (src_item, platform::CPUPlace (), dst_item);
64- }
57+ TensorCopySync (src_item, platform::CPUPlace (), dst_item);
6558 }
6659#else
67- if (platform::is_gpu_place (src_item.place ())) {
68- #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
69- TensorCopy (src_item, platform::CUDAPinnedPlace (), dev_ctx, dst_item);
60+ TensorCopySync (src_item, platform::CPUPlace (), dst_item);
7061#endif
71- } else {
72- TensorCopy (src_item, platform::CPUPlace (), dst_item);
73- }
74- #endif
75-
7662 } else {
7763 // Not copy, if the src tensor is empty.
7864 dst_item->clear ();
@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel {
9278 const std::string &var_name, const framework::Tensor &tensor,
9379 const framework::OpKernelType &expected_kernel_type) const override {
9480 return framework::OpKernelType (expected_kernel_type.data_type_ ,
95- expected_kernel_type.place_ ,
96- tensor.layout ());
81+ tensor.place (), tensor.layout ());
9782 }
9883
9984 framework::OpKernelType GetExpectedKernelType (
10085 const framework::ExecutionContext &ctx) const override {
10186 return framework::OpKernelType (
10287 OperatorWithKernel::IndicateVarDataType (ctx, " X" ),
103- ctx. device_context ());
88+ platform::CPUPlace ());
10489 }
10590};
10691
@@ -119,12 +104,10 @@ class FetchV2Kernel {
119104 if (fetch_var == nullptr ) {
120105 return ;
121106 }
122- PADDLE_ENFORCE_EQ (ctx. HasOutput ( " Out " ), true ,
123- platform::errors::NotFound (
124- " Output(Out) of memcpy_d2h_op is not found." ));
107+ PADDLE_ENFORCE_EQ (
108+ ctx. HasOutput ( " Out " ), true ,
109+ platform::errors::NotFound ( " Output(Out) of fetch_v2_op is not found." ));
125110 auto *out_var = ctx.OutputVar (" Out" );
126- // Get dev_ctx from ExecutionContext, it's D2H stream
127- auto &dev_ctx = ctx.device_context ();
128111
129112 int col = ctx.Attr <int >(" col" );
130113 PADDLE_ENFORCE_GE (
@@ -140,18 +123,34 @@ class FetchV2Kernel {
140123 fetch_list->resize (col + 1 );
141124 }
142125
126+ bool deepcopy = ctx.Attr <bool >(" deepcopy" );
127+
143128 if (fetch_var->IsType <framework::LoDTensor>()) {
144129 auto &src_item = fetch_var->Get <framework::LoDTensor>();
145130 auto *dst_item = &(BOOST_GET (framework::LoDTensor, fetch_list->at (col)));
146- DataCopy (src_item, fetch_var_name, dst_item, dev_ctx);
131+ PADDLE_ENFORCE_EQ (platform::is_cpu_place (src_item.place ()), true ,
132+ platform::errors::InvalidArgument (
133+ " Tensor's place of input(X) must be CPUPlace." ));
134+ if (deepcopy) {
135+ DeepCopy (src_item, fetch_var_name, dst_item);
136+ } else {
137+ dst_item->ShareDataWith (src_item);
138+ }
147139 } else {
148140 auto &src_item = fetch_var->Get <framework::LoDTensorArray>();
149141 framework::LoDTensorArray tmp (src_item.size ());
150142 fetch_list->at (col) = tmp;
151143 auto &dst_item =
152144 BOOST_GET (framework::LoDTensorArray, fetch_list->at (col));
153145 for (size_t i = 0 ; i < src_item.size (); ++i) {
154- DataCopy (src_item[i], fetch_var_name, &dst_item[i], dev_ctx);
146+ PADDLE_ENFORCE_EQ (platform::is_cpu_place (src_item[i].place ()), true ,
147+ platform::errors::InvalidArgument (
148+ " Tensor's place of input(X) must be CPUPlace." ));
149+ if (deepcopy) {
150+ DeepCopy (src_item[i], fetch_var_name, &dst_item[i]);
151+ } else {
152+ dst_item[i].ShareDataWith (src_item[i]);
153+ }
155154 }
156155 }
157156 }
@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
167166 " (vector<LoDTensor>) A fetching list of LoDTensor which may have "
168167 " different dimension, shape and data type." );
169168 AddAttr<int >(" col" , " (int) The column index of fetching object." );
169+ AddAttr<bool >(" deepcopy" , " (bool) Whether deep copy is required." )
170+ .SetDefault (true );
170171 AddComment (R"DOC(
171172FetchV2 Operator.
172173
@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
192193 int64_t , ops::FetchV2Kernel, bool ,
193194 ops::FetchV2Kernel, plat::float16,
194195 ops::FetchV2Kernel);
195-
196- #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
197- REGISTER_OP_CUDA_KERNEL_FUNCTOR (fetch_v2, float , ops::FetchV2Kernel, double ,
198- ops::FetchV2Kernel, int , ops::FetchV2Kernel,
199- int64_t , ops::FetchV2Kernel, bool ,
200- ops::FetchV2Kernel, plat::float16,
201- ops::FetchV2Kernel);
202- #endif
203-
204- #ifdef PADDLE_WITH_ASCEND_CL
205- REGISTER_OP_NPU_KERNEL_FUNCTOR (fetch_v2, float , ops::FetchV2Kernel, double ,
206- ops::FetchV2Kernel, int , ops::FetchV2Kernel,
207- int64_t , ops::FetchV2Kernel, bool ,
208- ops::FetchV2Kernel, plat::float16,
209- ops::FetchV2Kernel);
210- #endif
0 commit comments