Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct CommContext {
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false, int table_id_ = -1,
bool is_tensor_table_ = false)
bool is_tensor_table_ = false, bool is_datanorm_table_ = false,
int64_t program_id_ = -1)
: var_name(name),
splited_varnames(names),
epmap(emap),
Expand All @@ -42,7 +43,9 @@ struct CommContext {
is_sparse(is_sparse_),
is_distributed(is_distributed_),
table_id(table_id_),
is_tensor_table(is_tensor_table_) {}
program_id(program_id_),
is_tensor_table(is_tensor_table_),
is_datanorm_table(is_datanorm_table_) {}

CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
Expand All @@ -55,7 +58,9 @@ struct CommContext {
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
table_id = ctx.table_id;
program_id = ctx.program_id;
is_tensor_table = ctx.is_tensor_table;
is_datanorm_table = ctx.is_datanorm_table;
}

std::string print() const {
Expand All @@ -78,7 +83,9 @@ struct CommContext {
ss << " is_sparse: " << is_sparse;
ss << " is_distributed: " << is_distributed << "\n";
ss << " table_id: " << table_id << "\n";
ss << " program_id: " << program_id << "\n";
ss << " is_tensor_table: " << is_tensor_table << "\n";
ss << " is_datanorm_table: " << is_datanorm_table << "\n";

return ss.str();
}
Expand All @@ -93,7 +100,9 @@ struct CommContext {
bool is_sparse;
bool is_distributed;
int table_id;
int64_t program_id;
bool is_tensor_table;
bool is_datanorm_table;
};

} // namespace distributed
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) {
py::init<const std::string&, const std::vector<std::string>&,
const std::vector<std::string>&, const std::vector<int64_t>&,
const std::vector<std::string>&, int, bool, bool, bool, int,
bool>())
bool, bool, int64_t>())
.def("var_name", [](const CommContext& self) { return self.var_name; })
.def("trainer_id",
[](const CommContext& self) { return self.trainer_id; })
.def("table_id", [](const CommContext& self) { return self.table_id; })
.def("program_id",
[](const CommContext& self) { return self.program_id; })
.def("split_varnames",
[](const CommContext& self) { return self.splited_varnames; })
.def("split_endpoints",
Expand All @@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) {
[](const CommContext& self) { return self.origin_varnames; })
.def("is_tensor_table",
[](const CommContext& self) { return self.is_tensor_table; })
.def("is_datanorm_table",
[](const CommContext& self) { return self.is_datanorm_table; })
.def("__str__", [](const CommContext& self) { return self.print(); });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def _init_ps_pass_context(self, loss, startup_program):
attrs['loss'] = loss
attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]

attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distributed/passes/ps_trainer_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,9 @@ def _check_conflict(self, other_pass):
return True

def _get_sparse_table_names(self, attrs):
dist_varnames = get_sparse_tablenames(attrs['origin_main_program'],
dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
True)
sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'],
sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'],
False)
return list(set(dist_varnames + sparse_varnames))

Expand Down
Loading