Skip to content

Commit 78cc292

Browse files
authored
Switch test_transformer to eager mode and fix roll error (#41548) (#41617)
1 parent 74626f6 commit 78cc292

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

python/paddle/fluid/tests/unittests/dygraph_to_static/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS} FLAGS_enable_eager_mode=1)
66
set(TEST_EAGER_OPS test_bmn test_break_continue test_ifelse test_loop test_mnist_amp
77
test_mnist_pure_fp16 test_mobile_net test_program_translator test_ptb_lm test_reinforcement_learning
88
test_resnet test_resnet_amp test_resnet_pure_fp16 test_se_resnet test_sentiment test_seq2seq
9-
test_tsm test_word2vec test_yolov3 test_bert test_cycle_gan test_lstm test_simnet)
9+
test_tsm test_word2vec test_yolov3 test_bert test_cycle_gan test_lstm test_simnet test_transformer)
1010
list(REMOVE_ITEM TEST_OPS test_lac)
1111
# NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope will
1212
# be removed and will cause some random failed in multi-thread.

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,8 @@ def roll(x, shifts, axis=None, name=None):
780780
axis = []
781781

782782
if in_dygraph_mode():
783+
if isinstance(shifts, paddle.Tensor):
784+
shifts = shifts.cpu()
783785
return _C_ops.final_state_roll(x, shifts, axis)
784786

785787
if _in_legacy_dygraph():

0 commit comments

Comments
 (0)