
Commit c56d697: modify fetch logic, use D2H Stream (#35191)

* modify fetch logic, use D2H Stream, test=develop
* refine, test=develop

1 parent 7743cdf, commit c56d697
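
At a glance: fetch_v2 previously ran on its own fetch_context_pool_ device context and issued its own device-to-host TensorCopy, bracketed by explicit Wait() calls in the interpreter. After this commit the op is pinned to a CPU kernel; when the fetched tensor lives on the device, the new executor's place-transform path inserts a copy op (which runs on the D2H stream) and flips the op's new deepcopy attribute to false, so fetch_v2 aliases the freshly copied CPU tensor instead of copying it a second time. A minimal stand-alone sketch of that routing decision, with stand-in types rather than Paddle's real classes:

// Sketch only: models the decision BuildOpFuncList makes below; Place, Op,
// and NeedTransform are stand-ins, not Paddle APIs.
#include <iostream>
#include <string>

enum class Place { kCPU, kGPU };

struct Op {
  std::string type;
  bool deepcopy = true;  // mirrors the new attribute's default in fetch_v2_op.cc
};

bool NeedTransform(Place var_place, Place expected_place) {
  return var_place != expected_place;  // cf. platform::is_same_place(...)
}

int main() {
  Op fetch{"fetch_v2"};
  Place var_place = Place::kGPU;  // fetched tensor currently on the device
  Place expected = Place::kCPU;   // fetch_v2 now only registers a CPU kernel

  if (NeedTransform(var_place, expected)) {
    // The inserted D2H copy already materializes a fresh CPU tensor, so the
    // fetch kernel can share it rather than copy a second time.
    if (fetch.type == "fetch_v2") fetch.deepcopy = false;
    std::cout << "insert copy op on D2H stream, deepcopy=" << fetch.deepcopy
              << "\n";
  }
  return 0;
}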

File tree: 3 files changed (+34, -60 lines)


paddle/fluid/framework/new_executor/interpretercore.cc (4 additions, 13 deletions)

@@ -143,8 +143,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       main_program_(main_prog),
       global_scope_(global_scope),
       d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}),
-      fetch_context_pool_({place}) {
+      h2d_ctx_pool_({place}) {
   is_build_ = false;

   garbages_.reset(new GarbageQueue());
@@ -339,9 +338,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
       new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));

   auto* dev_ctx = instr_node->dev_ctx_;
-  if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") {
-    dev_ctx = fetch_context_pool_.Get(place);
-  }
   Scope scope;

   instr_node->execution_ctx_.reset(new ExecutionContext(
@@ -356,12 +352,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
           instr_node.kernel_func_.operator_base_)
       ->InferShape(instr_node.infershape_ctx_.get());

-  if (instr_node.kernel_func_.operator_base_->Type() == "fetch_v2") {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto* dev_ctx = pool.Get(place_);
-    dev_ctx->Wait();  // TODO(wanghuancoder)
-  }
-
   instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
 }

@@ -411,8 +401,6 @@ void InterpreterCore::ExecuteInstructionList(
                      working_var_ref);
   }

-  fetch_context_pool_.Get(place)->Wait();
-
   for (size_t i = 0; i < working_var_ref.size(); ++i) {
     if (working_var_ref[i].var_ref_count_ != 0) {
       std::cerr << " var ref is not zero " << i << std::endl;
@@ -671,6 +659,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
                                       expected_kernel_key);
     if (!platform::is_same_place(kernel_type_for_var.place_,
                                  expected_kernel_key.place_)) {
+      if (op_base->Type() == "fetch_v2") {
+        op_base->SetAttr("deepcopy", false);
+      }
       // need trans place
       // 1. add var in scope
       // 2. add copy op

paddle/fluid/framework/new_executor/interpretercore.h (0 additions, 2 deletions)

@@ -114,8 +114,6 @@ class InterpreterCore {
   size_t max_memory_size_;
   size_t cur_memory_size_;
   std::unique_ptr<WorkQueue> gc_queue_;
-
-  platform::DeviceContextPool fetch_context_pool_;
 };
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/operators/controlflow/fetch_v2_op.cc (30 additions, 45 deletions)

@@ -36,10 +36,9 @@ struct float16;
 namespace paddle {
 namespace operators {

-static void DataCopy(const framework::LoDTensor &src_item,
+static void DeepCopy(const framework::LoDTensor &src_item,
                      const std::string &fetch_var_name,
-                     framework::LoDTensor *dst_item,
-                     const platform::DeviceContext &dev_ctx) {
+                     framework::LoDTensor *dst_item) {
   if (src_item.IsInitialized() && src_item.numel() > 0) {
 #ifdef PADDLE_WITH_MKLDNN
     // Conversion from MKL-DNN to Paddle
@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item,
                     : paddle::platform::MKLDNNDeviceContext::tls()
                           .get_cur_paddle_data_layout(),
                 src_item, &out, platform::CPUPlace());
-      TensorCopy(src_item, platform::CPUPlace(), dev_ctx, dst_item);
+      TensorCopySync(out, platform::CPUPlace(), dst_item);
     } else {
-      if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-        TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
-#endif
-      } else {
-        TensorCopy(src_item, platform::CPUPlace(), dst_item);
-      }
+      TensorCopySync(src_item, platform::CPUPlace(), dst_item);
     }
 #else
-    if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
+    TensorCopySync(src_item, platform::CPUPlace(), dst_item);
 #endif
-    } else {
-      TensorCopy(src_item, platform::CPUPlace(), dst_item);
-    }
-#endif
-
   } else {
     // Not copy, if the src tensor is empty.
     dst_item->clear();
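
A hedged note on this hunk: TensorCopy takes a device context and enqueues the copy on that context's stream, returning before the data has necessarily arrived, whereas TensorCopySync returns only once the copy has completed. That is plausibly why the explicit dev_ctx->Wait() calls in interpretercore.cc could be removed above. A stand-alone sketch of the distinction in plain C++, with std::async standing in for the device stream (none of this is Paddle's API):

#include <future>
#include <vector>

// Models TensorCopy: work happens "on a stream"; the caller must synchronize.
std::future<void> AsyncCopy(const std::vector<float>& src,
                            std::vector<float>* dst) {
  return std::async(std::launch::async, [&src, dst] { *dst = src; });
}

// Models TensorCopySync: by the time this returns, *dst is fully written.
void SyncCopy(const std::vector<float>& src, std::vector<float>* dst) {
  *dst = src;
}

int main() {
  std::vector<float> src(4, 1.0f), dst;
  auto pending = AsyncCopy(src, &dst);
  pending.wait();       // the explicit Wait() this commit deletes
  SyncCopy(src, &dst);  // the replacement: implicitly synchronized
  return 0;
}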
@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel {
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
     return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+                                   tensor.place(), tensor.layout());
   }

   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+        platform::CPUPlace());
   }
 };

@@ -119,12 +104,10 @@ class FetchV2Kernel {
     if (fetch_var == nullptr) {
       return;
     }
-    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
-                      platform::errors::NotFound(
-                          "Output(Out) of memcpy_d2h_op is not found."));
+    PADDLE_ENFORCE_EQ(
+        ctx.HasOutput("Out"), true,
+        platform::errors::NotFound("Output(Out) of fetch_v2_op is not found."));
     auto *out_var = ctx.OutputVar("Out");
-    // Get dev_ctx from ExecutionContext, it's D2H stream
-    auto &dev_ctx = ctx.device_context();

     int col = ctx.Attr<int>("col");
     PADDLE_ENFORCE_GE(
@@ -140,18 +123,34 @@ class FetchV2Kernel {
       fetch_list->resize(col + 1);
     }

+    bool deepcopy = ctx.Attr<bool>("deepcopy");
+
     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
-      DataCopy(src_item, fetch_var_name, dst_item, dev_ctx);
+      PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item.place()), true,
+                        platform::errors::InvalidArgument(
+                            "Tensor's place of input(X) must be CPUPlace."));
+      if (deepcopy) {
+        DeepCopy(src_item, fetch_var_name, dst_item);
+      } else {
+        dst_item->ShareDataWith(src_item);
+      }
     } else {
       auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
       framework::LoDTensorArray tmp(src_item.size());
       fetch_list->at(col) = tmp;
       auto &dst_item =
           BOOST_GET(framework::LoDTensorArray, fetch_list->at(col));
       for (size_t i = 0; i < src_item.size(); ++i) {
-        DataCopy(src_item[i], fetch_var_name, &dst_item[i], dev_ctx);
+        PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item[i].place()), true,
+                          platform::errors::InvalidArgument(
+                              "Tensor's place of input(X) must be CPUPlace."));
+        if (deepcopy) {
+          DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
+        } else {
+          dst_item[i].ShareDataWith(src_item[i]);
+        }
       }
     }
   }
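
The kernel now chooses between two result-handoff strategies: DeepCopy materializes an independent CPU buffer, while ShareDataWith aliases the source tensor's buffer with no copy at all (safe here only because the interpreter set deepcopy=false precisely when an inserted copy op already produced a private CPU tensor). A stand-alone sketch of the difference, with a stand-in Tensor rather than Paddle's:

#include <cassert>
#include <memory>
#include <vector>

// Stand-in Tensor: a shared buffer holder, not Paddle's framework::Tensor.
struct Tensor {
  std::shared_ptr<std::vector<float>> holder;
  void ShareDataWith(const Tensor& src) { holder = src.holder; }  // alias
  void DeepCopyFrom(const Tensor& src) {                          // duplicate
    holder = std::make_shared<std::vector<float>>(*src.holder);
  }
};

int main() {
  Tensor src{std::make_shared<std::vector<float>>(3, 1.0f)};
  Tensor shared, copied;
  shared.ShareDataWith(src);  // deepcopy == false: zero-copy fetch
  copied.DeepCopyFrom(src);   // deepcopy == true: independent result
  (*src.holder)[0] = 2.0f;
  assert((*shared.holder)[0] == 2.0f);  // the alias sees later mutation
  assert((*copied.holder)[0] == 1.0f);  // the deep copy does not
  return 0;
}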
@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
              "(vector<LoDTensor>) A fetching list of LoDTensor which may have "
              "different dimension, shape and data type.");
     AddAttr<int>("col", "(int) The column index of fetching object.");
+    AddAttr<bool>("deepcopy", "(bool) Whether deep copy is required.")
+        .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.

@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
                                int64_t, ops::FetchV2Kernel, bool,
                                ops::FetchV2Kernel, plat::float16,
                                ops::FetchV2Kernel);
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                                ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                                int64_t, ops::FetchV2Kernel, bool,
-                                ops::FetchV2Kernel, plat::float16,
-                                ops::FetchV2Kernel);
-#endif
-
-#ifdef PADDLE_WITH_ASCEND_CL
-REGISTER_OP_NPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                               ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                               int64_t, ops::FetchV2Kernel, bool,
-                               ops::FetchV2Kernel, plat::float16,
-                               ops::FetchV2Kernel);
-#endif
