Add some dist-training robust cases into fluid benchmark test #11207
Changes from 2 commits
```diff
@@ -26,6 +26,10 @@
 import paddle.fluid.core as core
 import paddle.fluid.framework as framework
 from paddle.fluid.executor import Executor
+from models.model_base import get_decay_learning_rate
+from models.model_base import get_regularization
+from models.model_base import set_error_clip
+from models.model_base import set_gradient_clip


 def lstm_step(x_t, hidden_t_prev, cell_t_prev, size):
@@ -50,7 +54,7 @@ def linear(inputs):


 def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
-                   target_dict_dim, is_generating, beam_size, max_length):
+                   target_dict_dim, is_generating, beam_size, max_length, args):
     """Construct a seq2seq network."""

     def bi_lstm_encoder(input_seq, gate_size):
@@ -99,6 +103,8 @@ def bi_lstm_encoder(input_seq, gate_size):
                                    size=decoder_size,
                                    bias_attr=False,
                                    act='tanh')
+    set_error_clip(args.error_clip_method, encoded_proj.name,
+                   args.error_clip_min, args.error_clip_max)

     def lstm_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
                                     decoder_boot, decoder_size):
@@ -211,12 +217,24 @@ def get_model(args):
         dict_size,
         False,
         beam_size=beam_size,
-        max_length=max_length)
+        max_length=max_length,
+        args=args)

     # clone from default main program
     inference_program = fluid.default_main_program().clone()

-    optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
+    # set gradient clip
+    set_gradient_clip(args.gradient_clip_method, args.gradient_clip_norm)
+
+    optimizer = fluid.optimizer.Adam(
+        learning_rate=get_decay_learning_rate(
+            decay_method=args.learning_rate_decay_method,
+            learning_rate=args.learning_rate,
+            decay_steps=args.learning_rate_decay_steps,
+            decay_rate=args.learning_rate_decay_rate),
+        regularization=get_regularization(
+            regularizer_method=args.weight_decay_regularizer_method,
+            regularizer_coeff=args.weight_decay_regularizer_coeff))

     train_batch_generator = paddle.batch(
         paddle.reader.shuffle(
```
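The new code paths are all driven by attributes on `args`. The diff does not show where those flags are registered; below is a hypothetical sketch of the matching `argparse` wiring. The flag names mirror the `args.*` attributes read in the diff, while every type and default here is an assumption, not taken from the PR.

```python
# Hypothetical registration of the new robustness flags; names mirror
# the args.* attributes used in the diff, but the types and defaults
# are assumptions.
import argparse

parser = argparse.ArgumentParser(description='fluid benchmark robustness knobs')
parser.add_argument('--error_clip_method', type=str, default=None)
parser.add_argument('--error_clip_min', type=float, default=-1.0)
parser.add_argument('--error_clip_max', type=float, default=1.0)
parser.add_argument('--gradient_clip_method', type=str, default=None)
parser.add_argument('--gradient_clip_norm', type=float, default=1.0)
parser.add_argument('--learning_rate_decay_method', type=str, default=None)
parser.add_argument('--learning_rate_decay_steps', type=int, default=10000)
parser.add_argument('--learning_rate_decay_rate', type=float, default=0.9)
parser.add_argument('--weight_decay_regularizer_method', type=str, default=None)
parser.add_argument('--weight_decay_regularizer_coeff', type=float, default=1e-4)
args = parser.parse_args()
```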
model_base is not uploaded?
Thanks for the review. I added the benchmark/fluid/models/model_base.py file in the next commit.
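For context, here is a minimal sketch of what the four helpers in `benchmark/fluid/models/model_base.py` might look like, written against the fluid APIs of that era (`fluid.layers.exponential_decay`, `fluid.clip.set_gradient_clip`, `fluid.clip.ErrorClipByValue`, `fluid.regularizer.L2Decay`). The method-name strings, fallbacks, and the error-clip wiring are assumptions inferred from the call sites in the diff; the actual file is the one added in the later commit.

```python
# Hypothetical sketch of benchmark/fluid/models/model_base.py; the real
# file lands in a later commit. Method-name strings and fallbacks are
# assumptions inferred from the call sites in the diff above.
import paddle.fluid as fluid


def get_decay_learning_rate(decay_method, learning_rate, decay_steps,
                            decay_rate):
    # Without a decay method, hand the optimizer the plain float rate.
    if not decay_method:
        return learning_rate
    if decay_method == 'exponential_decay':
        # Returns a variable that decays every decay_steps mini-batches.
        return fluid.layers.exponential_decay(
            learning_rate=learning_rate,
            decay_steps=decay_steps,
            decay_rate=decay_rate)
    raise ValueError('unknown decay method: %s' % decay_method)


def get_regularization(regularizer_method, regularizer_coeff):
    # None disables weight decay entirely.
    if not regularizer_method:
        return None
    if regularizer_method == 'L1Decay':
        return fluid.regularizer.L1Decay(regularizer_coeff)
    if regularizer_method == 'L2Decay':
        return fluid.regularizer.L2Decay(regularizer_coeff)
    raise ValueError('unknown regularizer method: %s' % regularizer_method)


def set_gradient_clip(clip_method, clip_norm):
    # Registers a program-wide gradient clip before the optimizer
    # builds its backward pass.
    if not clip_method:
        return
    if clip_method == 'GlobalNorm':
        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm)
    elif clip_method == 'Norm':
        clip = fluid.clip.GradientClipByNorm(clip_norm=clip_norm)
    else:
        raise ValueError('unknown gradient clip method: %s' % clip_method)
    fluid.clip.set_gradient_clip(clip)


def set_error_clip(clip_method, var_name, clip_min, clip_max):
    # Attaches an error clip to one variable of the main program so its
    # gradient is clamped to [clip_min, clip_max] during backprop.
    if not clip_method:
        return
    fluid.default_main_program().block(0).var(var_name).set_error_clip(
        fluid.clip.ErrorClipByValue(max=clip_max, min=clip_min))
```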