Merged
145 changes: 95 additions & 50 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -57,6 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0);
}

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -140,11 +141,30 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars);
}

size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0;
for (auto var_name : var_names) {
auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GT(numel, 0);
numel_sum += numel;
}

auto smallest =
std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
balance_vars_[dev_id] += numel_sum;
return dev_id;
}

std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
Contributor:
After this change, can Build() only be called once? Do we want to clear "balance_vars_", "all_vars_", etc. at the beginning of Build()?
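
As an illustration of the suggestion only (not code from this PR), a minimal sketch of resetting the per-build state at the top of Build(); the member names follow this diff, while BuilderState and ResetForBuild are hypothetical stand-ins:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the builder's mutable bookkeeping; the real
// class stores VarDesc* values in all_vars_.
struct BuilderState {
  std::unordered_map<std::string, int> all_vars_;
  std::unordered_map<std::string, int> var_name_on_devices_;
  std::vector<int64_t> balance_vars_;
};

// Clearing these at the beginning of Build() would make repeated calls safe.
void ResetForBuild(BuilderState *s, std::size_t num_places) {
  s->all_vars_.clear();
  s->var_name_on_devices_.clear();
  s->balance_vars_.assign(num_places, 0);
}

int main() {
  BuilderState state;
  ResetForBuild(&state, 4);  // would run at the start of each Build()
  return 0;
}
```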

const ProgramDesc &program) const {
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars[var->Name()] = var;
all_vars_.emplace(var->Name(), var);
Contributor:
Does emplace have different semantics from []? If it's not necessary, let's keep it the same.

Contributor (author):
Not much; it just avoids an unnecessary copy, but it makes no difference here.
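
For context, a standalone example of the behavioral difference on a std::unordered_map (not code from this PR): operator[] default-constructs the mapped value and then assigns, always overwriting an existing entry, whereas emplace constructs in place and does nothing if the key already exists.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> m;

  // operator[] default-constructs the value if the key is missing, then
  // assigns; it always overwrites whatever was stored before.
  m["w"] = 1;
  m["w"] = 2;
  assert(m["w"] == 2);

  // emplace constructs the element in place and is a no-op when the key
  // is already present, so the first inserted value wins.
  m.emplace("b", 10);
  auto result = m.emplace("b", 20);
  assert(!result.second);  // the second insertion did not happen
  assert(m["b"] == 10);

  return 0;
}
```

In this diff the mapped types are pointers and ints, so the saved copy is negligible, which matches the author's point above.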

}

auto graph = new SSAGraph();
@@ -161,35 +181,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);

std::vector<std::unordered_set<std::string>> var_name_on_devices;
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
var_name_on_devices.resize(places_.size());
bcast_var_name_set.resize(places_.size());

size_t cur_device_id = 0;
std::vector<int64_t> balance_grads(places_.size(), 0);

auto get_appropriate_dev = [&](std::string &g_name) -> size_t {
auto var_desc = all_vars.at(g_name);
PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape());
int64_t numel = framework::product(dim);
PADDLE_ENFORCE_GE(numel, 0);
auto smallest =
std::min_element(std::begin(balance_grads), std::end(balance_grads));
size_t dev_id =
static_cast<size_t>(std::distance(std::begin(balance_grads), smallest));
balance_grads[dev_id] += numel;
return dev_id;
};

bool is_forwarding = true;

for (auto *op : program.Block(0).AllOps()) {
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateDistTrainOp(&result, *op);
@@ -201,13 +202,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id == -1) { // var on all device
CreateComputationalOps(&result, *op, places_.size());
} else {
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices[op_dev_id].emplace(var_name);
var_name_on_devices_.emplace(var_name, op_dev_id);
}
}
if (!is_forwarding && places_.size() > 1) {
@@ -230,13 +231,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(

switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = get_appropriate_dev(g_name);
cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id);
var_name_on_devices[cur_device_id].emplace(g_name);
var_name_on_devices_.emplace(g_name, cur_device_id);
bcast_var_name_set[cur_device_id].emplace(p_name);
break;
case BuildStrategy::ReduceStrategy::kAllReduce:
if (IsSparseGradient(all_vars, g_name)) {
if (IsSparseGradient(g_name)) {
CreateReduceOp(&result, g_name, 0);
CreateBroadcastOp(&result, g_name, 0);
} else {
@@ -273,11 +274,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph);
}

bool MultiDevSSAGraphBuilder::IsSparseGradient(
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const {
PADDLE_ENFORCE(all_vars.count(og) != 0);
if (all_vars.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
PADDLE_ENFORCE(all_vars_.count(og) != 0);
if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
return true;
}
return false;
Expand Down Expand Up @@ -363,24 +362,23 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
return is_pg_once;
}

int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const {
int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}

int var_dev_id = -1;
for (auto &var_name : op.InputArgumentNames()) {
if (var_dev_id != -1) break;
for (size_t i = 0; i < var_name_on_devices.size(); ++i) {
if (var_name_on_devices[i].count(var_name)) {
var_dev_id = static_cast<int>(i);
break;
}
for (auto &varname : op.InputArgumentNames()) {
int dev_id = GetVarDeviceID(varname);
if (dev_id != -1) {
return dev_id;
}
}
return var_dev_id;
return -1;
}

int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
auto got = var_name_on_devices_.find(varname);
return got == var_name_on_devices_.end() ? -1 : got->second;
}

void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
@@ -463,16 +461,65 @@ void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,

void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
int op_dev_id = -1;
if (op.Type() == "split_byref") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else if (op.Type() == "concat") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
} else {
PADDLE_ENFORCE(
"the distribute training related op should be in [split_byref, "
"concat].");
}

PADDLE_ENFORCE(op_dev_id != -1,
"can not find right place for distributed op: %s", op.Type());

CreateComputationalOp(result, op, op_dev_id);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}

void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
result->ops_.emplace_back(
new RPCOpHandle(op, local_scopes_[0], op.Type(), places_[0]));
int op_dev_id = -1;
if (op.Type() == "send") {
op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]);
// the variable name which contains .block means it was split by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
op.InputArgumentNames()[0].find(".block") == std::string::npos) {
op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames());
for (auto &varname : op.InputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
}
} else if (op.Type() == "recv") {
op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames());
for (auto &varname : op.OutputArgumentNames()) {
var_name_on_devices_.emplace(varname, op_dev_id);
}
} else {
// send_barrier and fetch_barrier op can be scheduled on device 0
op_dev_id = 0;
}

PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
op.Type());

result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
op.Type(), places_[op_dev_id]));

if (op.Type() == "send_barrier") {
ConnectOp(result, result->ops_.back().get(), "send");
@@ -488,9 +535,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
"send, send_barrier. recv, fetch_barrier]");
}

// TODO(Yancey1989): schedule rpc op on different place may
// increate throughput
CreateOpHandleIOs(result, op, 0);
CreateOpHandleIOs(result, op, op_dev_id);
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
17 changes: 10 additions & 7 deletions paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -47,10 +47,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#endif

std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
int GetVarDeviceID(const std::string &varname) const;

private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
size_t place_id) const;
size_t device_id) const;

private:
std::string loss_var_name_;
@@ -96,21 +97,23 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &og,
std::unordered_set<std::string> *og_has_been_broadcast) const;

int GetOpDeviceID(
const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
const OpDesc &op) const;
int GetOpDeviceID(const OpDesc &op) const;

void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;

void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const;

bool IsSparseGradient(
const std::unordered_map<std::string, VarDesc *> &all_vars,
const std::string &og) const;
bool IsSparseGradient(const std::string &og) const;

size_t GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const;

private:
BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_;
mutable std::unordered_map<std::string, int> var_name_on_devices_;
Contributor:
We should not use an unordered_map to record which device a var_name is on, because the same var_name may appear on different devices.

Contributor (author):
It may not. This map does not record all variables; it is only used for the Reduce strategy and for distributed training.

For the Reduce strategy, we schedule each Reduce op on a different device and record the gradient variable name in var_name_on_devices_, so a gradient appears on only one device.

For distributed training, it is the same as the Reduce strategy: we schedule send_op and recv_op on different devices, so a variable name never appears on more than one device either.
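
A small standalone sketch (not from this PR) of the contract described above, using made-up variable names: a gradient placed by the Reduce strategy is recorded for exactly one device, and any name that was never recorded resolves to -1, which GetOpDeviceID treats as "replicated on all devices".

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for var_name_on_devices_ and GetVarDeviceID().
int GetVarDeviceID(const std::unordered_map<std::string, int> &var_devices,
                   const std::string &varname) {
  auto got = var_devices.find(varname);
  return got == var_devices.end() ? -1 : got->second;
}

int main() {
  std::unordered_map<std::string, int> var_devices;

  // Under the Reduce strategy each gradient is reduced on one device only,
  // so its name is recorded exactly once.
  var_devices.emplace("fc_0.w_0@GRAD", 1);
  assert(GetVarDeviceID(var_devices, "fc_0.w_0@GRAD") == 1);

  // Names that were never recorded (e.g. variables replicated on every
  // device) resolve to -1.
  assert(GetVarDeviceID(var_devices, "fc_0.b_0@GRAD") == -1);
  return 0;
}
```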

mutable std::vector<int64_t> balance_vars_;

void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
1 change: 1 addition & 0 deletions paddle/fluid/framework/details/ssa_graph_builder.h
@@ -30,6 +30,7 @@ class SSAGraphBuilder {
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
virtual int GetVarDeviceID(const std::string &var_name) const { return -1; }

DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);

23 changes: 18 additions & 5 deletions paddle/fluid/framework/parallel_executor.cc
@@ -110,7 +110,6 @@ ParallelExecutor::ParallelExecutor(

// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp

details::SSAGraphBuilderFactory builder_factory(
member_->places_, loss_var_name, params, member_->local_scopes_,
build_strategy);
@@ -122,9 +121,10 @@
#endif
}

builder_ = std::move(builder_factory.Create());
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places,
builder_factory.Create()->Build(main_program)));
builder_->Build(main_program)));

member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
@@ -133,10 +133,22 @@

void ParallelExecutor::BCastParamsToGPUs(
const std::unordered_set<std::string> &vars) const {
auto *main_scope = member_->local_scopes_[0];
// in the initial bcast, all vars would be bcast from device(0), otherwise
// bcast from the specified device.
bool initialize = builder_.get() == nullptr ? true : false;

for (auto &var : vars) {
auto *main_var = main_scope->FindVar(var);
int var_dev_id =
builder_.get() == nullptr ? -1 : builder_->GetVarDeviceID(var);
if (!initialize && var_dev_id == -1) continue;

framework::Variable *main_var = nullptr;
if (initialize) {
main_var = member_->local_scopes_[0]->FindVar(var);
} else {
main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
}

if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
continue;
}
@@ -151,7 +163,8 @@ void ParallelExecutor::BCastParamsToGPUs(
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto place = member_->places_[i];
void *buffer;
if (i == 0) {

if ((initialize && i == 0) || (!initialize && i == var_dev_id)) {
buffer = const_cast<void *>(main_tensor.data<void>());
} else {
auto local_scope = member_->local_scopes_[i];
3 changes: 3 additions & 0 deletions paddle/fluid/framework/parallel_executor.h
@@ -19,12 +19,14 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {

@@ -68,6 +70,7 @@ class ParallelExecutor {

private:
ParallelExecutorPrivate *member_;
std::unique_ptr<details::SSAGraphBuilder> builder_;
};

} // namespace framework
3 changes: 3 additions & 0 deletions paddle/fluid/operators/tensorrt_engine_op.h
@@ -16,6 +16,9 @@

#ifdef PADDLE_WITH_CUDA

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
1 change: 0 additions & 1 deletion paddle/fluid/operators/tensorrt_engine_op_test.cc
@@ -179,7 +179,6 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
const std::string& z_name, bool x_created,
const shape_t& x_shape, const shape_t& y_shape,
const shape_t& z_shape) {

LOG(INFO) << "create fc op";
auto* fc = block_desc.AppendOp();
fc->SetType("mul");