overlap rpc op memcpy in distributed training #11221
Changes from all commits: 93401c9, 6d69ae0, 82d741c, cb38615, e533a4b, 15913d9, 23433de, d5a88b9, 4444e79, 6d752ba, f52d78d, 3d875b6, 7d1b146, 7e6518e
Changes to the MultiDevSSAGraphBuilder implementation:

```diff
@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
+  balance_vars_.resize(places_.size(), 0);
 }
 
 void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
   checker(op.InputArgumentNames(), recv_vars);
 }
 
+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
-  std::unordered_map<std::string, VarDesc *> all_vars;
   for (auto *var : program.Block(0).AllVars()) {
-    all_vars[var->Name()] = var;
+    all_vars_.emplace(var->Name(), var);
   }
 
   auto graph = new SSAGraph();
```
Review comment (on `all_vars_.emplace(var->Name(), var);`): Not too much — it just avoids some unnecessary copying, but it makes no difference here.
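To make the reviewer's point concrete, here is a minimal standalone sketch (my illustration, not code from the PR) of the difference between `operator[]` and `emplace` on an unordered_map: `operator[]` default-constructs the mapped value and then assigns to it, overwriting any existing entry, while `emplace` constructs the element in place and leaves an existing key untouched.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> m;

  // operator[] default-constructs the mapped value, then assigns to it,
  // and always overwrites whatever was stored before.
  m["x"] = 1;
  m["x"] = 2;  // value replaced

  // emplace constructs the pair in place and is a no-op if the key exists.
  auto result = m.emplace("x", 3);
  std::cout << "inserted: " << std::boolalpha << result.second  // false
            << ", value: " << m["x"] << "\n";                   // prints 2
  return 0;
}
```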
```diff
@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   auto send_vars = FindDistTrainSendVars(program);
   auto recv_vars = FindDistTrainRecvVars(program);
 
-  std::vector<std::unordered_set<std::string>> var_name_on_devices;
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
-  var_name_on_devices.resize(places_.size());
   bcast_var_name_set.resize(places_.size());
 
   size_t cur_device_id = 0;
-  std::vector<int64_t> balance_grads(places_.size(), 0);
-
-  auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
-    auto var_desc = all_vars.at(g_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GE(numel, 0);
-    auto smallest =
-        std::min_element(std::begin(balance_grads), std::end(balance_grads));
-    size_t dev_id =
-        static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
-    balance_grads[dev_id] += numel;
-    return dev_id;
-  };
-
   bool is_forwarding = true;
 
   for (auto *op : program.Block(0).AllOps()) {
     if (boost::get<int>(
             op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      // append rpc op if program is distributed trainer main program.
-      // always use the first device
       CreateRPCOp(&result, *op);
     } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
       CreateDistTrainOp(&result, *op);
@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
+      int op_dev_id = GetOpDeviceID(*op);
       if (op_dev_id == -1) {  // var on all device
         CreateComputationalOps(&result, *op, places_.size());
       } else {
         CreateComputationalOp(&result, *op, op_dev_id);
         for (auto &var_name : op->OutputArgumentNames()) {
-          var_name_on_devices[op_dev_id].emplace(var_name);
+          var_name_on_devices_.emplace(var_name, op_dev_id);
         }
       }
       if (!is_forwarding && places_.size() > 1) {
@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
           switch (strategy_.reduce_) {
             case BuildStrategy::ReduceStrategy::kReduce:
-              cur_device_id = get_appropriate_dev(g_name);
+              cur_device_id = GetAppropriateDeviceID({g_name});
               CreateReduceOp(&result, g_name, cur_device_id);
-              var_name_on_devices[cur_device_id].emplace(g_name);
+              var_name_on_devices_.emplace(g_name, cur_device_id);
               bcast_var_name_set[cur_device_id].emplace(p_name);
               break;
             case BuildStrategy::ReduceStrategy::kAllReduce:
-              if (IsSparseGradient(all_vars, g_name)) {
+              if (IsSparseGradient(g_name)) {
                 CreateReduceOp(&result, g_name, 0);
                 CreateBroadcastOp(&result, g_name, 0);
               } else {
@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   return std::unique_ptr<SSAGraph>(graph);
 }
 
-bool MultiDevSSAGraphBuilder::IsSparseGradient(
-    const std::unordered_map<std::string, VarDesc *> &all_vars,
-    const std::string &og) const {
-  PADDLE_ENFORCE(all_vars.count(og) != 0);
-  if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
     return true;
   }
   return false;
@@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
   return is_pg_once;
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-    const OpDesc &op) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
 
-  int var_dev_id = -1;
-  for (auto &var_name : op.InputArgumentNames()) {
-    if (var_dev_id != -1) break;
-    for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
-      if (var_name_on_devices[i].count(var_name)) {
-        var_dev_id = static_cast<int>(i);
-        break;
-      }
+  for (auto &varname : op.InputArgumentNames()) {
+    int dev_id = GetVarDeviceID(varname);
+    if (dev_id != -1) {
+      return dev_id;
     }
   }
-  return var_dev_id;
+  return -1;
+}
+
+int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
+  auto got = var_name_on_devices_.find(varname);
+  return got == var_name_on_devices_.end() ? -1 : got->second;
 }
 
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
```
```diff
@@ -463,16 +461,65 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
 
 void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
                                                 const OpDesc &op) const {
-  CreateComputationalOp(result, op, 0);
+  int op_dev_id = -1;
+  if (op.Type() == "split_byref") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else if (op.Type() == "concat") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+  } else {
+    PADDLE_ENFORCE(
+        "the distribute training related op should be in [split_byref, "
+        "concat].");
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1,
+                 "can not find right place for distributed op: %s", op.Type());
+
+  CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
     ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
   }
 }
 
 void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
-  result->ops_.emplace_back(
-      new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
+  int op_dev_id = -1;
+  if (op.Type() == "send") {
+    op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
+    // the variable name which contains .block means it was splited by
+    // split_byref op
+    // so that we can balance the variable blocks to all the pserver instances.
+    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+        op.InputArgumentNames()[0].find(".block") == std::string::npos) {
+      op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
+      for (auto &varname : op.InputArgumentNames()) {
+        var_name_on_devices_.emplace(varname, op_dev_id);
+      }
+    }
+  } else if (op.Type() == "recv") {
+    op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
+    for (auto &varname : op.OutputArgumentNames()) {
+      var_name_on_devices_.emplace(varname, op_dev_id);
+    }
+  } else {
+    // send_barrier and fetch_barrier op can be scheduled on device 0
+    op_dev_id = 0;
+  }
+
+  PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
+                 op.Type());
+
+  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
+                                            op.Type(), places_[op_dev_id]));
 
   if (op.Type() == "send_barrier") {
     ConnectOp(result, result->ops_.back().get(), "send");
@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                    "send, send_barrier. recv, fetch_barrier]");
   }
 
-  // TODO(Yancey1989): schedule rpc op on different place may
-  // increate throughput
-  CreateOpHandleIOs(result, op, 0);
+  CreateOpHandleIOs(result, op, op_dev_id);
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
```
Changes to the MultiDevSSAGraphBuilder class declaration (header):

```diff
@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
+  int GetVarDeviceID(const std::string &varname) const;
 
  private:
   void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
-                         size_t place_id) const;
+                         size_t device_id) const;
 
  private:
   std::string loss_var_name_;
@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                                 const std::string &og,
                                 std::unordered_set<std::string> *og_has_been_broadcast) const;
 
-  int GetOpDeviceID(
-      const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
-      const OpDesc &op) const;
+  int GetOpDeviceID(const OpDesc &op) const;
 
   void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
 
   void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
                          size_t src_dev_id) const;
 
-  bool IsSparseGradient(
-      const std::unordered_map<std::string, VarDesc *> &all_vars,
-      const std::string &og) const;
+  bool IsSparseGradient(const std::string &og) const;
+
+  size_t GetAppropriateDeviceID(
+      const std::vector<std::string> &var_names) const;
 
  private:
   BuildStrategy strategy_;
+  mutable std::unordered_map<std::string, VarDesc *> all_vars_;
+  mutable std::unordered_map<std::string, int> var_name_on_devices_;
+  mutable std::vector<int64_t> balance_vars_;
 
   void SetCommunicationContext(OpHandleBase *op_handle,
                                const platform::Place &p) const;
```
Review comment (on the `var_name_on_devices_` member): We should not use an unordered_map to record the var_name on devices, because the same var_name may be on different devices.

Reply: It may not — this map does not record all variables; it is only used for the Reduce strategy and for distributed training. For the Reduce strategy, we schedule each Reduce op on a chosen device and record the gradient variable name in var_name_on_devices_. For distributed training, the same as the Reduce strategy, we schedule the distributed and RPC ops on chosen devices and record their variable names there as well.
Review comment: After this change, can Build() only be called once? Do we want to clear balance_vars_, all_vars_, etc. at the beginning of Build()?
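For reference, a minimal sketch of what the reviewer is suggesting might look like — hypothetical code, not part of this PR — where the mutable bookkeeping is reset at the top of Build() so the builder can be reused:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified builder: reset per-build state on entry so that
// calling Build() twice does not reuse stale placements or load counters.
class GraphBuilderSketch {
 public:
  explicit GraphBuilderSketch(size_t num_places) : num_places_(num_places) {}

  void Build() {
    // Clear bookkeeping carried over from a previous Build() call.
    all_vars_.clear();
    var_name_on_devices_.clear();
    balance_vars_.assign(num_places_, 0);
    // ... graph construction would follow here ...
  }

 private:
  size_t num_places_;
  std::unordered_map<std::string, int> all_vars_;  // simplified value type
  std::unordered_map<std::string, int> var_name_on_devices_;
  std::vector<int64_t> balance_vars_;
};
```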