Skip to content

Commit 678a259

Browse files
authored
Support Multi-Stream, Single-Thread in New Executor (#35024)
* Modify into QueueSync QueueAsync * fix compile on MacOS * fix pointer * fix conflict * polish unittest * fix windows fetch error * polish code according to reviewer * fix device_guard on CPU place
1 parent d1a33bc commit 678a259

File tree

9 files changed

+830
-33
lines changed

9 files changed

+830
-33
lines changed

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 265 additions & 25 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/new_executor/interpretercore.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "paddle/fluid/framework/program_desc.h"
2525
#include "paddle/fluid/framework/tensor.h"
2626
#include "paddle/fluid/framework/variable.h"
27+
#include "paddle/fluid/platform/event.h"
2728

2829
namespace paddle {
2930
namespace framework {
@@ -63,9 +64,22 @@ class InterpreterCore {
6364
void BuildVariableScope(const framework::ProgramDesc& pdesc,
6465
VariableScope* var_scope);
6566

67+
platform::DeviceContext* ParseDeviceContextForInstruction(
68+
const OpFuncNode& op_func_node, const OperatorBase& op_base);
69+
70+
void RecordEventInstruction(const Instruction& instruction,
71+
const OpFuncNode& op_func_node);
72+
73+
void WaitOrSync(const std::vector<EventInter>& events,
74+
const platform::DeviceContext* dev_ctx);
75+
76+
void StreamWaitEventOrSync(const Instruction& instruction);
77+
6678
const platform::Place& place_;
6779
ProgramDesc main_program_;
6880
VariableScope* global_scope_;
81+
platform::DeviceContextPool d2h_ctx_pool_;
82+
platform::DeviceContextPool h2d_ctx_pool_;
6983
std::vector<VariableMetaInfo> vec_meta_info_;
7084

7185
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
@@ -80,6 +94,7 @@ class InterpreterCore {
8094
bool is_build_;
8195

8296
std::vector<std::string> feed_names_;
97+
std::map<size_t, std::shared_ptr<platform::CudaEvent>> var_id2event_;
8398

8499
platform::DeviceContextPool fetch_context_pool_;
85100
};

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <vector>
2020

2121
#include "paddle/fluid/framework/operator.h"
22+
#include "paddle/fluid/platform/event.h"
2223

2324
namespace paddle {
2425
namespace framework {
@@ -41,22 +42,30 @@ struct VariableScope {
4142
std::map<std::string, int> name2id;
4243
};
4344

45+
struct EventRun {
46+
explicit EventRun(size_t op_id) : op_id_(op_id) {}
47+
size_t op_id_;
48+
};
4449
struct NextInstruction {
4550
std::vector<size_t> direct_run_;
51+
std::vector<EventRun> event_wait_run_;
52+
std::vector<EventRun> synchronize_run_;
53+
std::vector<size_t> all_next_ops_;
4654
};
4755

48-
struct EventInter {};
56+
struct EventInter {
57+
explicit EventInter(size_t var_id, std::shared_ptr<platform::CudaEvent> event,
58+
bool is_sync)
59+
: var_id_(var_id), event_(event), is_sync_(is_sync) {}
60+
size_t var_id_;
61+
std::shared_ptr<platform::CudaEvent> event_;
62+
bool is_sync_;
63+
};
4964

5065
struct InstructionInfo {
5166
std::vector<size_t> dependecy_count_;
5267
};
5368

54-
struct EventRun {
55-
EventInter event_inter;
56-
std::vector<size_t> same_device_run_;
57-
std::vector<size_t> synchronized_run;
58-
};
59-
6069
struct Instruction {
6170
OpKernelFunc kernel_func_;
6271
std::shared_ptr<RuntimeContext> runtime_ctx_;
@@ -67,7 +76,16 @@ struct Instruction {
6776

6877
std::vector<size_t> gc_check_var_list;
6978
NextInstruction next_instruction_;
70-
std::vector<EventInter> vec_event_list_;
79+
80+
std::vector<EventInter> intput_events_;
81+
std::vector<EventInter> output_events_;
82+
83+
platform::DeviceContext* dev_ctx_; // not owned
84+
};
85+
86+
enum class OpFuncType {
87+
kQueueAsync, // GPU Kernel or d2h, h2d, send, recv, broadcast
88+
kQueueSync, // CPU kernel, block host
7189
};
7290

7391
struct OpFuncNode {
@@ -76,6 +94,8 @@ struct OpFuncNode {
7694
std::map<std::string, std::vector<int>> output_index;
7795

7896
OpKernelComputeFunc kernel_func_;
97+
platform::DeviceContext* dev_ctx_; // not owned
98+
OpFuncType type_;
7999
};
80100

81101
} // namespace framework

paddle/fluid/framework/new_executor/standalone_executor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
9191
auto iter = interpretercores_.find(oss.str());
9292

9393
if (iter == interpretercores_.end()) {
94+
VLOG(3) << "create interpreter_core for " << oss.str();
9495
auto core = std::make_shared<InterpreterCore>(
9596
place_, main_prog_, &global_scope_, feed_names, fetch_names);
9697
interpretercores_.emplace(oss.str(), core);
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/memcpy_d2h_op.h"
13+
14+
#include <string>
15+
16+
namespace paddle {
17+
namespace framework {
18+
class OpDesc;
19+
class InferShapeContext;
20+
template <typename T>
21+
class EmptyGradOpMaker;
22+
} // namespace framework
23+
namespace imperative {
24+
class OpBase;
25+
} // namespace imperative
26+
namespace platform {
27+
struct CPUPlace;
28+
struct CUDAPlace;
29+
struct float16;
30+
} // namespace platform
31+
} // namespace paddle
32+
33+
namespace paddle {
34+
namespace operators {
35+
36+
class MemcpyD2HOp : public framework::OperatorWithKernel {
37+
public:
38+
using framework::OperatorWithKernel::OperatorWithKernel;
39+
40+
void InferShape(framework::InferShapeContext *ctx) const override {
41+
auto type = ctx->GetInputsVarType("X")[0];
42+
if (type == framework::proto::VarType::SELECTED_ROWS ||
43+
type == framework::proto::VarType::LOD_TENSOR) {
44+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
45+
if (type == framework::proto::VarType::LOD_TENSOR) {
46+
ctx->ShareLoD("X", /*->*/ "Out");
47+
}
48+
}
49+
}
50+
51+
protected:
52+
framework::OpKernelType GetKernelTypeForVar(
53+
const std::string &var_name, const framework::Tensor &tensor,
54+
const framework::OpKernelType &expected_kernel_type) const override {
55+
return framework::OpKernelType(expected_kernel_type.data_type_,
56+
expected_kernel_type.place_,
57+
tensor.layout());
58+
}
59+
60+
framework::OpKernelType GetExpectedKernelType(
61+
const framework::ExecutionContext &ctx) const override {
62+
return framework::OpKernelType(
63+
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
64+
ctx.device_context());
65+
}
66+
};
67+
68+
// Propagates both the variable type and the data type from input X to
// output Out, so Out is always the same kind of variable as X.
class MemcpyD2HInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    ctx->SyncTypeAndDataType("X", "Out");
  }
};
74+
75+
class MemcpyD2HKernel {
76+
public:
77+
void operator()(const framework::ExecutionContext &ctx) const {
78+
auto *x = ctx.InputVar("X");
79+
if (x == nullptr) {
80+
return;
81+
}
82+
PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
83+
platform::errors::NotFound(
84+
"Output(Out) of memcpy_d2h_op is not found."));
85+
auto *out = ctx.OutputVar("Out");
86+
// Get dev_ctx from ExecutionContext, it's D2H stream
87+
auto &dev_ctx = ctx.device_context();
88+
auto dst_place_type = ctx.Attr<int>("dst_place_type");
89+
framework::VisitVarType(*x, MemcpyD2HFunctor(out, dev_ctx, dst_place_type));
90+
}
91+
};
92+
93+
// Declares the op's proto: one input X, one output Out, and the integer
// attribute "dst_place_type" selecting the host-side destination place.
class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) The input variable ");
    AddOutput("Out",
              "(LoDTensor) The type of output "
              "is the same as input X.");
    AddAttr<int>("dst_place_type",
                 "Determine the dst place of tensor copy. "
                 "By Now it ONLY support NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU"
                 "Other place type is Unimplemented and will cause ERROR."
                 "0: dst is on CPUPlace. "
                 "1: dst is on CUDAPinnedPlace. ");
    AddComment(R"DOC(
MemcpyD2H Operator.
By now, it ONLY supports the memcopy between NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU.
You would have to update it if you want other more capacities.
Out = X, when type in [LoDTensor]
raise error if the type is not listed above.
)DOC");
  }
};
116+
117+
} // namespace operators
118+
} // namespace paddle
119+
120+
namespace ops = paddle::operators;
121+
namespace plat = paddle::platform;
122+
REGISTER_OPERATOR(
123+
memcpy_d2h, ops::MemcpyD2HOp, ops::MemcpyD2HOpProtoMaker,
124+
ops::MemcpyD2HInferVarType,
125+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
126+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
127+
128+
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double,
129+
ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel,
130+
int64_t, ops::MemcpyD2HKernel, bool,
131+
ops::MemcpyD2HKernel, plat::float16,
132+
ops::MemcpyD2HKernel);
133+
134+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
135+
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double,
136+
ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel,
137+
int64_t, ops::MemcpyD2HKernel, bool,
138+
ops::MemcpyD2HKernel, plat::float16,
139+
ops::MemcpyD2HKernel);
140+
#endif
141+
142+
#ifdef PADDLE_WITH_ASCEND_CL
143+
REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_d2h, float, ops::MemcpyD2HKernel, double,
144+
ops::MemcpyD2HKernel, int, ops::MemcpyD2HKernel,
145+
int64_t, ops::MemcpyD2HKernel, bool,
146+
ops::MemcpyD2HKernel, plat::float16,
147+
ops::MemcpyD2HKernel);
148+
#endif
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#pragma once
13+
14+
#include "paddle/fluid/framework/data_type.h"
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/framework/var_type.h"
17+
#include "paddle/fluid/platform/device_context.h"
18+
19+
namespace paddle {
20+
namespace platform {
21+
class DeviceContext;
22+
} // namespace platform
23+
} // namespace paddle
24+
25+
namespace paddle {
26+
namespace framework {
27+
class LoDTensor;
28+
class Variable;
29+
class SelectedRows;
30+
} // namespace framework
31+
} // namespace paddle
32+
33+
namespace paddle {
34+
namespace operators {
35+
class MemcpyD2HFunctor {
36+
public:
37+
MemcpyD2HFunctor(framework::Variable *out,
38+
const platform::DeviceContext &dev_ctx,
39+
const int dst_place_type)
40+
: out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {}
41+
42+
void operator()(const framework::LoDTensor &lod_tensor) const {
43+
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
44+
45+
if (dst_place_type_ == 1) {
46+
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
47+
&out_tensor);
48+
} else if (dst_place_type_ == 0) {
49+
framework::TensorCopySync(lod_tensor, platform::CPUPlace(), &out_tensor);
50+
} else {
51+
PADDLE_THROW(platform::errors::Unimplemented(
52+
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
53+
}
54+
out_tensor.set_lod(lod_tensor.lod());
55+
}
56+
57+
void operator()(const framework::SelectedRows &rows) const {
58+
// (JZ-LIANG) to support SelectedRows
59+
PADDLE_THROW(platform::errors::Unimplemented(
60+
"Memcpy for SelectedRows is NOT support yet."));
61+
}
62+
63+
template <typename T>
64+
void operator()(const T &v) const {
65+
PADDLE_ENFORCE_EQ(
66+
true, false,
67+
platform::errors::PermissionDenied(
68+
"Not support type for Memcpy op with type %s", typeid(T).name()));
69+
}
70+
71+
private:
72+
framework::Variable *out_;
73+
const platform::DeviceContext &dev_ctx_;
74+
const int dst_place_type_;
75+
};
76+
77+
} // namespace operators
78+
} // namespace paddle

0 commit comments

Comments
 (0)