 #include <glog/logging.h>
 
 #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
+#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -37,24 +41,179 @@ bool IsPersistable(const framework::VarDesc *var) {
 
 bool DistModel::Init() {
   /* TODO(fleet exe dev): implement this funct */
-  place_ = paddle::platform::CUDAPlace(config_.device_id);
-  if (!PrepareScope()) {
+  bool init_method = (!config_.model_dir.empty() || config_.program_desc);
+  PADDLE_ENFORCE_EQ(init_method, true,
+                    platform::errors::InvalidArgument(
+                        "One of model dir or program desc must be provided to "
+                        "dist model inference."));
+  if (config_.program_desc) {
+    PADDLE_ENFORCE_NOT_NULL(
+        config_.scope, platform::errors::InvalidArgument(
+                           "Scope must be provided to dist model inference if "
+                           "program desc has been provided."));
+  }
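+  // Past this point the config is known to carry either a model dir or a
+  // (program desc, scope) pair to build the model from.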
+  if (!PreparePlace()) {
     return false;
   }
-  if (!PrepareProgram()) {
+  if (!config_.program_desc) {
+    if (config_.scope) {
+      LOG(WARNING) << "The provided scope will be ignored if model dir has "
+                      "also been provided.";
+    }
+    if (!PrepareScope()) {
+      return false;
+    }
+    if (!PrepareProgram()) {
+      return false;
+    }
+  } else {
+    program_.reset(config_.program_desc);
+    scope_.reset(config_.scope);
+  }
+  if (!PrepareFeedAndFetch()) {
     return false;
   }
   if (!CommInit()) {
     return false;
   }
+  if (!PrepareFleetExe()) {
+    return false;
+  }
+  return true;
+}
+
+bool DistModel::PreparePlace() {
+  if (config_.place == "GPU") {
+    place_ = paddle::platform::CUDAPlace(config_.device_id);
+  } else if (config_.place == "CPU") {
+    place_ = paddle::platform::CPUPlace();
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "Place must be chosen from GPU or CPU, but got %s.", config_.place));
+  }
   return true;
 }
 
 bool DistModel::CommInit() {
-  // TODO(fleet executor): init the comm
+  // NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
+  // that mp part is always on inner side and pp part is always on outer side.
+  // TODO(fleet exe dev): The peer endpoints could be configured by users.
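+  // For example, with nranks = 4, mp_degree = 2 and pp_degree = 2, ranks
+  // {0, 1} form the first pp stage and ranks {2, 3} the second: the mp group
+  // of rank r starts at (r / mp_degree) * mp_degree, and its pp neighbours
+  // are ranks r - mp_degree and r + mp_degree.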
+  PADDLE_ENFORCE_EQ(
+      config_.pp_degree * config_.mp_degree, config_.nranks,
+      platform::errors::InvalidArgument(
+          "The product of mp_degree and pp_degree must equal nranks."));
+  std::unique_ptr<framework::ProgramDesc> comm_init_program(
+      new framework::ProgramDesc());
+  framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
+  if (config_.mp_degree > 1) {
+    PADDLE_ENFORCE_GE(
+        config_.mp_ring_id, 0,
+        platform::errors::InvalidArgument(
+            "mp ring id must be provided for inference under mp."));
+    VLOG(3) << "Init comm group for mp.";
+    std::vector<std::string> peer_endpoints;
+    for (int64_t
+             idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
+             i = 0;
+         i < config_.mp_degree; ++idx, ++i) {
+      if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
+        continue;
+      }
+      peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
+    }
+    // get nranks in a mp group and inner group rank for local rank
+    int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
+    int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
+    InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
+                 comm_init_block, config_.mp_ring_id);
+  }
+  if (config_.pp_degree > 1) {
+    // NOTE: the last pp stage doesn't need to init the pp comm
+    VLOG(3) << "Init comm group for pp.";
+    if (config_.local_rank - config_.mp_degree >= 0) {
+      PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp upstream ring id must be provided for "
+                            "non-first pp stage if inference under pp."));
+      // not the first pp stage, has upstream
+      std::vector<std::string> upstream_peer_endpoints;
+      upstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
+      InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
+                   comm_init_block, config_.pp_upstream_ring_id);
+    }
+
+    if (config_.local_rank + config_.mp_degree < config_.nranks) {
+      PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
+                        platform::errors::InvalidArgument(
+                            "pp downstream ring id must be provided for "
+                            "non-last pp stage if inference under pp."));
+      // not the last pp stage, has downstream
+      std::vector<std::string> downstream_peer_endpoints;
+      downstream_peer_endpoints.emplace_back(
+          config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
+      InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
+                   comm_init_block, config_.pp_downstream_ring_id);
+    }
+  }
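+  // Run the generated comm-init program once on a NaiveExecutor so that all
+  // communicators exist before the fleet executor starts running inference.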
+  framework::NaiveExecutor e(place_);
+  e.CreateVariables(*comm_init_program, 0, true, scope_.get());
+  e.Prepare(scope_.get(), *comm_init_program, 0, false);
+  e.Run();
+  VLOG(3) << "Comm init successful.";
   return true;
 }
 
+void DistModel::InsertCommOp(std::string tmp_var_name, int nranks, int rank,
+                             const std::vector<std::string> &peer_endpoints,
+                             framework::BlockDesc *block, int ring_id) {
+  /*
+   * tmp_var_name: the var name to hold the generated comm id
+   * nranks: number of total ranks
+   * rank: the rank of local rank in the comm group
+   * peer_endpoints: peer's endpoints
+   * block: the block where to insert the comm ops
+   * ring_id: the ring_id to be inited
+   */
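+  // On GPU places this appends a c_gen_nccl_id op, which exchanges an NCCL
+  // unique id with the peer endpoints, followed by a c_comm_init op that
+  // builds the communicator for the given ring_id.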
+  std::string &endpoint = config_.current_endpoint;
+  std::stringstream ss;
+  ss << "Init comm with tmp var: " << tmp_var_name
+     << ". The ring id is: " << ring_id << ". The group has: " << nranks
+     << " ranks. Current rank in the group is: " << rank
+     << ". The endpoint is: " << endpoint << ". Peer endpoints are: ";
+  for (auto ep : peer_endpoints) {
+    ss << ep << ", ";
+  }
+  VLOG(3) << ss.str();
+  if (config_.place == "GPU") {
+    framework::VarDesc *new_var = block->Var(tmp_var_name);
+    new_var->SetType(framework::proto::VarType::RAW);
+    new_var->SetPersistable(true);
+    framework::OpDesc *gen_nccl_id_op = block->AppendOp();
+    gen_nccl_id_op->SetType("c_gen_nccl_id");
+    gen_nccl_id_op->SetOutput("Out", {tmp_var_name});
+    gen_nccl_id_op->SetAttr("rank", rank);
+    gen_nccl_id_op->SetAttr("endpoint", config_.current_endpoint);
+    gen_nccl_id_op->SetAttr("other_endpoints", peer_endpoints);
+    gen_nccl_id_op->SetAttr("ring_id", ring_id);
+    gen_nccl_id_op->SetAttr("op_role",
+                            static_cast<int>(framework::OpRole::kForward));
+    gen_nccl_id_op->CheckAttrs();
+    framework::OpDesc *comm_init_op = block->AppendOp();
+    comm_init_op->SetType("c_comm_init");
+    comm_init_op->SetInput("X", {tmp_var_name});
+    comm_init_op->SetAttr("rank", rank);
+    comm_init_op->SetAttr("nranks", nranks);
+    comm_init_op->SetAttr("ring_id", ring_id);
+    comm_init_op->SetAttr("op_role",
+                          static_cast<int>(framework::OpRole::kForward));
+    comm_init_op->CheckAttrs();
+  } else {
+    LOG(WARNING) << "DistModelInf doesn't init comm.";
+    // TODO(fleet exe dev): comm init for more devices
+  }
+}
+
 bool DistModel::PrepareScope() {
   scope_.reset(new framework::Scope());
   return true;
@@ -119,6 +278,8 @@ bool DistModel::LoadParameters() {
       new_var->SetLoDLevel(var->GetLoDLevel());
       new_var->SetPersistable(true);
       params.push_back(new_var->Name());
+      // NOTE: if the params are stored in different files, 'load' op should
+      // be added here
     }
   }
 
@@ -145,6 +306,55 @@ bool DistModel::LoadParameters() {
   return true;
 }
 
+bool DistModel::PrepareFleetExe() {
+  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
+  if (config_.local_rank - config_.mp_degree >= 0) {
+    task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
+  }
+  if (config_.local_rank + config_.mp_degree < config_.nranks) {
+    task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
+  }
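+  // The up/downstream tasks mirror the pp neighbours wired up in CommInit:
+  // the task on rank r exchanges data with ranks r - mp_degree and
+  // r + mp_degree.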
+  task_node_->SetType("Compute");
+  task_node_->Init();
+  executor_desc_ = FleetExecutorDesc();
+  executor_desc_.set_cur_rank(config_.local_rank);
+  std::unordered_map<int64_t, int64_t> id_to_rank;
+  for (int i = 0; i < config_.nranks; ++i) {
+    RankInfo *rank_info = executor_desc_.add_cluster_info();
+    rank_info->set_rank(i);
+    rank_info->set_ip_port(config_.trainer_endpoints[i]);
+    id_to_rank.insert({i, i});
+  }
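+  // With exactly one task per rank, the task-id-to-rank mapping is the
+  // identity.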
+  fleet_exe.reset(new FleetExecutor(executor_desc_));
+  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
+                  {task_node_.get()}, id_to_rank);
+  return true;
+}
+
+bool DistModel::PrepareFeedAndFetch() {
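+  // Scan block 0 for feed/fetch ops and record, per 'col' slot, the op
+  // itself and the mapping between the slot index and the variable name.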
+  for (auto *op : program_->Block(0).AllOps()) {
+    if (op->Type() == "feed") {
+      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
+        feeds_.resize(idx + 1);
+      }
+      feeds_[idx] = op;
+      feed_names_[op->Output("Out")[0]] = idx;
+      idx_to_feeds_[idx] = op->Output("Out")[0];
+    } else if (op->Type() == "fetch") {
+      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
+      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
+      if (fetches_.size() <= static_cast<size_t>(idx)) {
+        fetches_.resize(idx + 1);
+      }
+      fetches_[idx] = op;
+      id_to_fetches_[idx] = op->Input("X")[0];
+    }
+  }
+  return true;
+}
+
 void DistModel::Run(const std::vector<paddle::framework::Tensor> &input_data,
                     std::vector<paddle::framework::Tensor> *output_data) {
   /* TODO(fleet exe dev): implement this funct */