Commit 98e3ef3

Merge branch 'develop' into concat_kernel_latest
2 parents 9e88d11 + 8784ec6 commit 98e3ef3

96 files changed: 3410 additions & 872 deletions

paddle/fluid/distributed/fleet_executor/dist_model.cc

Lines changed: 214 additions & 4 deletions
@@ -15,7 +15,11 @@
 #include <glog/logging.h>

 #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
+#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -37,24 +41,179 @@ bool IsPersistable(const framework::VarDesc *var) {

 bool DistModel::Init() {
   /* TODO(fleet exe dev): implement this funct */
-  place_ = paddle::platform::CUDAPlace(config_.device_id);
-  if (!PrepareScope()) {
+  bool init_method = (!config_.model_dir.empty() || config_.program_desc);
+  PADDLE_ENFORCE_EQ(init_method, true,
+                    platform::errors::InvalidArgument(
+                        "One of model dir or program desc must be provided to "
+                        "dist model inference."));
+  if (config_.program_desc) {
+    PADDLE_ENFORCE_NOT_NULL(
+        config_.scope, platform::errors::InvalidArgument(
+                           "Scope must be provided to dist model inference if "
+                           "program desc has been provided."));
+  }
+  if (!PreparePlace()) {
     return false;
   }
-  if (!PrepareProgram()) {
+  if (!config_.program_desc) {
+    if (config_.scope) {
+      LOG(WARNING) << "The provided scope will be ignored if model dir has "
+                      "also been provided.";
+    }
+    if (!PrepareScope()) {
+      return false;
+    }
+    if (!PrepareProgram()) {
+      return false;
+    }
+  } else {
+    program_.reset(config_.program_desc);
+    scope_.reset(config_.scope);
+  }
+  if (!PrepareFeedAndFetch()) {
     return false;
   }
   if (!CommInit()) {
     return false;
   }
+  if (!PrepareFleetExe()) {
+    return false;
+  }
+  return true;
+}
+
+bool DistModel::PreparePlace() {
+  if (config_.place == "GPU") {
+    place_ = paddle::platform::CUDAPlace(config_.device_id);
+  } else if (config_.place == "CPU") {
+    place_ = paddle::platform::CPUPlace();
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Place must be choosen from GPU or CPU, but got %s.", config_.place));
+  }
   return true;
 }

 bool DistModel::CommInit() {
-  // TODO(fleet executor): init the comm
+  // NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
+  // that mp part is always on inner side and pp part is always on outer side.
+  // TODO(fleet exe dev): The peer endpoints could be configured by users.
+  PADDLE_ENFORCE_EQ(
+      config_.pp_degree * config_.mp_degree, config_.nranks,
+      platform::errors::InvalidArgument(
+          "The mp_degree multiplies pp_degree is not equal with nranks"));
+  std::unique_ptr<framework::ProgramDesc> comm_init_program(
+      new framework::ProgramDesc());
+  framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
+  if (config_.mp_degree > 1) {
+    PADDLE_ENFORCE_GE(
+        config_.mp_ring_id, 0,
+        platform::errors::InvalidArgument(
+            "mp ring id must be provided for inference under mp."));
+    VLOG(3) << "Init comm group for mp.";
+    std::vector<std::string> peer_endpoints;
+    for (int64_t
+             idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
+             i = 0;
+         i < config_.mp_degree; ++idx, ++i) {
+      if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
+        continue;
+      }
+      peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
+    }
+    // get nranks in a mp group and inner group rank for local rank
+    int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
+    int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
+    InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
+                 comm_init_block, config_.mp_ring_id);
+  }
+  if (config_.pp_degree) {
+    // NOTE: the last pp stage doesn't need init pp comm
+    VLOG(3) << "Init comm group for pp.";
+    if (config_.local_rank - config_.mp_degree >= 0) {
+      PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp upstream ring id must be provided for "
+                            "non-first pp stage if inference under pp."));
+      // not the first pp stage, has upstream
+      std::vector<std::string> upstream_peer_endpoints;
+      upstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
+      InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
+                   comm_init_block, config_.pp_upstream_ring_id);
+    }
+
+    if (config_.local_rank + config_.mp_degree < config_.nranks) {
+      PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp downstream ring id must be provided for "
+                            "non-last pp stage if inference under pp."));
+      // not the last pp stage, has downstream
+      std::vector<std::string> downstream_peer_endpoints;
+      downstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
+      InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
+                   comm_init_block, config_.pp_downstream_ring_id);
+    }
+  }
+  framework::NaiveExecutor e(place_);
+  e.CreateVariables(*comm_init_program, 0, true, scope_.get());
+  e.Prepare(scope_.get(), *comm_init_program, 0, false);
+  e.Run();
+  VLOG(3) << "Comm init successful.";
   return true;
 }

+void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
+                             const std::vector<std::string> &peer_endpoints,
+                             framework::BlockDesc *block, int ring_id) {
+  /*
+   * tmp_var_name: the var name for var comm_id
+   * nranks: number of total ranks
+   * rank: the rank of local rank in the comm group
+   * peer_endpoints: peer's endpoints
+   * block: the block where to insert the comm ops
+   * ring_id: the ring_id to be inited
+   */
+  std::string &endpoint = config_.current_endpoint;
+  std::stringstream ss;
+  ss << "Init comm with tmp var: " << tmp_var_name
+     << ". The ring id is: " << ring_id << ". The group has: " << nranks
+     << " ranks. Current rank in the group is: " << rank
+     << ". The endpoint is: " << endpoint << ". Peer endpoints are: ";
+  for (auto ep : peer_endpoints) {
+    ss << ep << ", ";
+  }
+  VLOG(3) << ss.str();
+  if (config_.place == "GPU") {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_nccl_id_op = block->AppendOp();
+    gen_nccl_id_op->SetType("c_gen_nccl_id");
+    gen_nccl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_nccl_id_op->SetAttr("rank", rank);
+    gen_nccl_id_op->SetAttr("endpoint", config_.current_endpoint);
+    gen_nccl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_nccl_id_op->SetAttr("ring_id", ring_id);
+    gen_nccl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_nccl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
+  } else {
+    LOG(WARNING) << "DistModelInf doesn't init comm.";
+    // TODO(fleet exe dev): comm init for more devices
+  }
+}
+
 bool DistModel::PrepareScope() {
   scope_.reset(new framework::Scope());
   return true;
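
For illustration (not part of this commit): the peer-endpoint selection in CommInit walks the mp_degree consecutive ranks that form the current rank's mp group and skips the rank's own endpoint. A minimal, Paddle-free sketch of that indexing, with hypothetical endpoint values:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical cluster: mp_degree = 2, pp_degree = 2, nranks = 4.
  const std::vector<std::string> trainer_endpoints = {
      "127.0.0.1:6000", "127.0.0.1:6001", "127.0.0.1:6002", "127.0.0.1:6003"};
  const int64_t mp_degree = 2;
  const int64_t local_rank = 3;  // second rank of the second mp group
  const std::string current_endpoint = trainer_endpoints[local_rank];

  // Same loop shape as DistModel::CommInit: start at the first rank of this
  // rank's mp group, take mp_degree consecutive ranks, and skip ourselves.
  std::vector<std::string> peer_endpoints;
  for (int64_t idx = (local_rank / mp_degree) * mp_degree, i = 0;
       i < mp_degree; ++idx, ++i) {
    if (trainer_endpoints[idx] == current_endpoint) {
      continue;
    }
    peer_endpoints.emplace_back(trainer_endpoints[idx]);
  }

  for (const auto &ep : peer_endpoints) {
    std::cout << ep << "\n";  // prints only 127.0.0.1:6002
  }
  return 0;
}
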
@@ -119,6 +278,8 @@ bool DistModel::LoadParameters() {
       new_var->SetLoDLevel(var->GetLoDLevel());
       new_var->SetPersistable(true);
       params.push_back(new_var->Name());
+      // NOTE: if the params are stored in different files, 'load' op should be
+      // added here
     }
   }

@@ -145,6 +306,55 @@ bool DistModel::LoadParameters() {
   return true;
 }

+bool DistModel::PrepareFleetExe() {
+  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
+  if (config_.local_rank - config_.mp_degree >= 0) {
+    task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
+  }
+  if (config_.local_rank + config_.mp_degree < config_.nranks) {
+    task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
+  }
+  task_node_->SetType("Compute");
+  task_node_->Init();
+  executor_desc_ = FleetExecutorDesc();
+  executor_desc_.set_cur_rank(config_.local_rank);
+  std::unordered_map<int64_t, int64_t> id_to_rank;
+  for (int i = 0; i < config_.nranks; ++i) {
+    RankInfo *rank_info = executor_desc_.add_cluster_info();
+    rank_info->set_rank(i);
+    rank_info->set_ip_port(config_.trainer_endpoints[i]);
+    id_to_rank.insert({i, i});
+  }
+  fleet_exe.reset(new FleetExecutor(executor_desc_));
+  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
+                  {task_node_.get()}, id_to_rank);
+  return true;
+}
+
+bool DistModel::PrepareFeedAndFetch() {
+  for (auto *op : program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      idx_to_feeds_[idx] = op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (fetches_.size() <= static_cast<size_t>(idx)) {
+        fetches_.resize(idx + 1);
+      }
+      fetches_[idx] = op;
+      id_to_fetches_[idx] = op->Input("X")[0];
+    }
+  }
+  return true;
+}
+
 void DistModel::Run(const std::vector<paddle::framework::Tensor> &input_data,
                     std::vector<paddle::framework::Tensor> *output_data) {
   /* TODO(fleet exe dev): implement this funct */
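
For illustration (not part of this commit): PrepareFeedAndFetch keys each feed/fetch op by its "col" attribute and grows the vectors on demand, so a scan that meets the ops out of column order still lands each one in its own slot. A minimal, Paddle-free sketch of that indexing, with hypothetical variable names:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical feed ops found in block order: ("col" attribute, output var).
  const std::vector<std::pair<int, std::string>> feed_ops = {{1, "image"},
                                                             {0, "label"}};

  std::vector<std::string> feeds;               // stands in for feeds_
  std::map<std::string, int64_t> feed_names;    // var name -> column
  std::map<int64_t, std::string> idx_to_feeds;  // column -> var name

  for (const auto &op : feed_ops) {
    const int idx = op.first;  // the op's "col" attribute
    if (feeds.size() <= static_cast<size_t>(idx)) {
      feeds.resize(idx + 1);
    }
    feeds[idx] = op.second;
    feed_names[op.second] = idx;
    idx_to_feeds[idx] = op.second;
  }

  // Columns end up in order regardless of scan order: "label" then "image".
  std::cout << feeds[0] << " " << feeds[1] << "\n";
  return 0;
}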

paddle/fluid/distributed/fleet_executor/dist_model.h

Lines changed: 26 additions & 2 deletions
@@ -23,22 +23,33 @@
 #include "paddle/fluid/platform/place.h"

 namespace paddle {
+
 namespace framework {
 class ProgramDesc;
 class Scope;
+class BlockDesc;
 }

 namespace distributed {

+class TaskNode;
+class FleetExecutor;
+
 struct DistModelConfig {
   std::string model_dir{};
+  framework::ProgramDesc* program_desc{nullptr};
+  framework::Scope* scope{nullptr};
+  std::string place{};
+  int64_t device_id{0};
   std::vector<std::string> trainer_endpoints{};
   std::string current_endpoint{};
   int64_t nranks{1};
   int64_t local_rank{0};
-  int64_t device_id{0};
   int64_t mp_degree{1};
   int64_t pp_degree{1};
+  int64_t mp_ring_id{-1};
+  int64_t pp_upstream_ring_id{-1};
+  int64_t pp_downstream_ring_id{-1};
 };

 class DistModel {
@@ -56,12 +67,25 @@ class DistModel {
   bool PrepareProgram();
   bool LoadProgram();
   bool LoadParameters();
+  bool PreparePlace();
   bool CommInit();
+  bool PrepareFeedAndFetch();
+  bool PrepareFleetExe();
+  void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
+                    const std::vector<std::string>& peer_endpoints,
+                    framework::BlockDesc* block, int ring_id);

+  std::vector<framework::OpDesc*> feeds_;
+  std::map<std::string, int64_t> feed_names_;
+  std::map<int64_t, std::string> idx_to_feeds_;
+  std::vector<framework::OpDesc*> fetches_;
+  std::map<int64_t, std::string> id_to_fetches_;
   DistModelConfig config_;
   FleetExecutorDesc executor_desc_;
-  platform::Place place_;
+  std::shared_ptr<FleetExecutor> fleet_exe;
+  std::shared_ptr<TaskNode> task_node_;
   std::shared_ptr<framework::Scope> scope_;
+  paddle::platform::Place place_;
   std::shared_ptr<framework::ProgramDesc> program_;
 };
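
For illustration (not part of this commit): the new DistModelConfig fields work together; CommInit requires mp_degree * pp_degree == nranks, the ring ids default to -1 (unset), and PreparePlace accepts only "GPU" or "CPU". A hypothetical 2-way mp x 2-way pp configuration, sketched under the assumption that the Paddle headers above are on the include path (the path and ports are made up):

#include <string>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"

paddle::distributed::DistModelConfig MakeConfig() {
  paddle::distributed::DistModelConfig config;
  config.model_dir = "/path/to/inference_model";  // hypothetical path
  config.place = "GPU";                           // "GPU" or "CPU", per PreparePlace()
  config.device_id = 0;
  config.trainer_endpoints = {"127.0.0.1:6000", "127.0.0.1:6001",
                              "127.0.0.1:6002", "127.0.0.1:6003"};
  config.current_endpoint = "127.0.0.1:6000";
  config.nranks = 4;  // CommInit checks mp_degree * pp_degree == nranks
  config.local_rank = 0;
  config.mp_degree = 2;
  config.pp_degree = 2;
  config.mp_ring_id = 0;             // required once mp_degree > 1
  config.pp_upstream_ring_id = 1;    // required for every non-first pp stage
  config.pp_downstream_ring_id = 2;  // required for every non-last pp stage
  return config;
}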

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
   InitMessageBus();
 }

+FleetExecutor::FleetExecutor(const FleetExecutorDesc& exe_desc)
+    : exe_desc_(exe_desc) {
+  // Message bus will be created and inited only once
+  GlobalVal<MessageBus>::Create();
+  InitMessageBus();
+}
+
 FleetExecutor::~FleetExecutor() {
   for (const auto& carrier_id : carrier_ids_) {
     GlobalMap<std::string, Carrier>::Get(carrier_id)->Release();
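
For illustration (not part of this commit): the new overload accepts an already-built FleetExecutorDesc rather than a serialized string, which is what DistModel::PrepareFleetExe relies on. A hedged sketch of building such a desc and constructing the executor, mirroring the calls used in PrepareFleetExe (the cluster layout is hypothetical):

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"

std::shared_ptr<paddle::distributed::FleetExecutor> MakeExecutor(
    const std::vector<std::string> &trainer_endpoints, int64_t cur_rank) {
  paddle::distributed::FleetExecutorDesc desc;
  desc.set_cur_rank(cur_rank);
  for (int i = 0; i < static_cast<int>(trainer_endpoints.size()); ++i) {
    auto *rank_info = desc.add_cluster_info();  // same calls as PrepareFleetExe
    rank_info->set_rank(i);
    rank_info->set_ip_port(trainer_endpoints[i]);
  }
  // The new constructor consumes the desc directly, with no string round trip.
  return std::make_shared<paddle::distributed::FleetExecutor>(desc);
}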

paddle/fluid/distributed/fleet_executor/fleet_executor.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class FleetExecutor final {
  public:
   FleetExecutor() = delete;
   explicit FleetExecutor(const std::string& exe_desc_str);
+  explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
   ~FleetExecutor();
   void Init(const std::string& carrier_id,
             const framework::ProgramDesc& program_desc, framework::Scope* scope,

paddle/fluid/distributed/fleet_executor/task_node.cc

Lines changed: 10 additions & 0 deletions
@@ -38,6 +38,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
   task_id_ = task_node_cnt++;
 }

+TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
+    : program_(program), rank_(rank), task_id_(rank) {
+  max_run_times_ = 1;
+  max_slot_nums_ = 1;
+  LOG(INFO)
+      << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
+      << rank
+      << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
+}
+
 void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
   program_ = program;
 }
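
For illustration (not part of this commit): the two-argument constructor produces an inference TaskNode whose task id equals its rank and whose max_run_times/max_slot_nums are pinned to 1; upstream/downstream wiring then follows the same rank arithmetic as DistModel::PrepareFleetExe. A hedged sketch, assuming the Paddle headers are available:

#include <cstdint>

#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/program_desc.h"

void BuildInferenceTaskNode(paddle::framework::ProgramDesc *program,
                            int64_t local_rank, int64_t mp_degree,
                            int64_t nranks) {
  // New constructor: task_id_ == rank, max_run_times_ == max_slot_nums_ == 1.
  paddle::distributed::TaskNode node(program, local_rank);
  if (local_rank - mp_degree >= 0) {
    node.AddUpstreamTask(local_rank - mp_degree);  // previous pipeline stage
  }
  if (local_rank + mp_degree < nranks) {
    node.AddDownstreamTask(local_rank + mp_degree);  // next pipeline stage
  }
  node.SetType("Compute");
  node.Init();
}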

paddle/fluid/distributed/fleet_executor/task_node.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class TaskNode final {
            int64_t max_slot_nums);
   TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
            int64_t max_run_times, int64_t max_slot_nums);
+  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
   ~TaskNode() = default;

   void SetProgram(paddle::framework::ProgramDesc* program);
