Fix dist error with lr decay layer #9489
@@ -24,6 +24,7 @@
 from regularizer import append_regularization_ops
 from clip import append_gradient_clip_ops, error_clip_callback
 from contextlib import contextmanager
+from distribute_transpiler import UnionFind

 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
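The UnionFind helper imported above is defined in distribute_transpiler and its implementation is not part of this diff. Judging only from how it is used in _get_lr_decay_ops below (constructed from a list of ops, then union and is_connected), a minimal disjoint-set sketch with that interface might look as follows; this is an illustration of the assumed interface, not the actual fluid code:

class UnionFindSketch(object):
    """Toy disjoint-set keyed by id(), mirroring the interface that
    distribute_transpiler.UnionFind appears to expose here:
    construction from a list of elements, union(a, b), is_connected(a, b)."""

    def __init__(self, elems):
        # Key by id() so arbitrary (possibly unhashable) op objects work.
        self._parent = dict((id(e), id(e)) for e in elems)

    def _find(self, elem):
        root = id(elem)
        while self._parent[root] != root:
            root = self._parent[root]
        return root

    def union(self, a, b):
        root_a, root_b = self._find(a), self._find(b)
        if root_a != root_b:
            self._parent[root_b] = root_a

    def is_connected(self, a, b):
        return self._find(a) == self._find(b)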
@@ -172,6 +173,42 @@ def _get_accumulator(self, name, param):
                             format(name, param.name))
         return self._accumulators[name][param.name]

+    def _get_lr_decay_ops(self):
+        def __is_op_connected(op1, op2):
+            op1_input_names = op1.input_arg_names
+            op1_output_names = op1.output_arg_names
+
+            op2_input_names = op2.input_arg_names
+            op2_output_names = op2.output_arg_names
+
+            if set(op1_output_names) & set(op2_input_names) or \
+               set(op1_input_names) & set(op2_output_names):
+                return True
+            return False
+
+        ret_ops = []
+        if isinstance(self._learning_rate, framework.Variable):
+            output_op_idx = -1
+            global_block = framework.default_main_program().global_block()
+
+            for idx, op in enumerate(global_block.ops):
+                if self._learning_rate.name in op.output_arg_names:
+                    output_op_idx = idx
+                    break
+            sliced_ops = global_block.slice_ops(0, output_op_idx + 1)
+            ufind = UnionFind(sliced_ops)
+            for _, op1 in enumerate(sliced_ops):
+                for _, op2 in enumerate(sliced_ops):
+                    if op1 != op2 and __is_op_connected(op1, op2):
+                        ufind.union(op1, op2)
+
+            for _, op in enumerate(sliced_ops):
+                if ufind.is_connected(op, global_block.ops[output_op_idx]):
+                    ret_ops.append(op)
+            ret_ops.append(global_block.ops[output_op_idx])
+
+        return ret_ops
+
     def create_optimization_pass(self,
                                  parameters_and_grads,
                                  loss,
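In words, _get_lr_decay_ops looks at every op up to and including the one that writes the learning-rate variable, unions ops that share a variable (one op's output is another's input), and returns every op in the same connected component as the learning-rate op, i.e. the learning-rate decay sub-graph. The following self-contained toy reproduces that grouping; FakeOp, the op names and the variable names are made up for illustration and are not fluid API:

import collections

# Stand-in for a framework operator: just lists of input/output variable names.
FakeOp = collections.namedtuple(
    'FakeOp', ['type', 'input_arg_names', 'output_arg_names'])

def _is_op_connected(op1, op2):
    # Two ops are "connected" when one produces a variable the other consumes.
    return bool(set(op1.output_arg_names) & set(op2.input_arg_names) or
                set(op1.input_arg_names) & set(op2.output_arg_names))

def toy_lr_decay_ops(ops, lr_var_name):
    # Index of the op that writes the learning-rate variable.
    lr_idx = next(i for i, op in enumerate(ops)
                  if lr_var_name in op.output_arg_names)
    sliced = ops[:lr_idx + 1]

    # Minimal union-find over indices of the sliced ops.
    parent = list(range(len(sliced)))

    def find(i):
        while parent[i] != i:
            i = parent[i]
        return i

    for i, op1 in enumerate(sliced):
        for j, op2 in enumerate(sliced):
            if i != j and _is_op_connected(op1, op2):
                ri, rj = find(i), find(j)
                if ri != rj:
                    parent[rj] = ri

    # Keep every op in the same component as the learning-rate op.
    return [op for i, op in enumerate(sliced) if find(i) == find(lr_idx)]

ops = [
    FakeOp('fill_constant', [], ['global_step']),
    FakeOp('conv2d', ['images'], ['features']),  # unrelated to the LR
    FakeOp('exponential_decay', ['global_step'], ['learning_rate']),
]
print([op.type for op in toy_lr_decay_ops(ops, 'learning_rate')])
# prints: ['fill_constant', 'exponential_decay']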
@@ -217,9 +254,11 @@ def create_optimization_pass(self,
         # Get custom finish ops for subclasses
         # FIXME: Need to fix this once we figure out how to handle dependencies
         self._finish_update(loss.block)

         end = len(global_block.ops)
-        return global_block.slice_ops(start, end)
+        lr_decay_ops = self._get_lr_decay_ops()
+        optimize_ops = global_block.slice_ops(start, end)
+        return lr_decay_ops, optimize_ops

     def minimize(self,
                  loss,
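The return contract of create_optimization_pass therefore changes from a single list of ops to a (lr_decay_ops, optimize_ops) pair, where lr_decay_ops is empty whenever the learning rate is a plain Python number rather than a decayed Variable. A tiny stub with entirely hypothetical names and no fluid dependency, showing only the shape callers should now expect to unpack:

class StubOptimizer(object):
    """Hypothetical stand-in that mimics only the new return shape."""

    def __init__(self, learning_rate):
        self._learning_rate = learning_rate

    def create_optimization_pass(self, params_grads, loss, startup_program=None):
        # A constant float learning rate has no decay sub-graph.
        lr_decay_ops = [] if isinstance(self._learning_rate, float) else ['lr_decay_op']
        optimize_ops = ['sgd_op_for_' + param for param, _ in params_grads]
        return lr_decay_ops, optimize_ops

lr_decay_ops, optimize_ops = StubOptimizer(0.01).create_optimization_pass(
    [('fc_0.w_0', 'fc_0.w_0@GRAD')], loss=None)
assert lr_decay_ops == []                      # constant lr -> no decay ops
assert optimize_ops == ['sgd_op_for_fc_0.w_0']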
@@ -242,9 +281,9 @@ def minimize(self,
         params_grads = append_regularization_ops(params_grads,
                                                  self.regularization)

-        optimize_ops = self.create_optimization_pass(params_grads, loss,
-                                                     startup_program)
-        return optimize_ops, params_grads
+        lr_decay_ops, optimize_ops = self.create_optimization_pass(
+            params_grads, loss, startup_program)
+        return lr_decay_ops, optimize_ops, params_grads,


 class SGDOptimizer(Optimizer):
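Because minimize now returns (lr_decay_ops, optimize_ops, params_grads) instead of (optimize_ops, params_grads), user code that unpacks its result needs a matching update. A sketch of the caller-side change using the usual fluid program-building API; the toy network itself is just an assumed example and is not part of this diff:

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)

sgd = fluid.optimizer.SGD(learning_rate=0.001)

# Before this PR:
#     optimize_ops, params_grads = sgd.minimize(avg_cost)
# After this PR, minimize also hands back the learning-rate decay ops
# (an empty list here, since the learning rate is a constant float):
lr_decay_ops, optimize_ops, params_grads = sgd.minimize(avg_cost)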
Review comment on this change: That's smart.