Skip to content

Commit 20f56e6

Browse files
author
Your Name
committed
fix bug
1 parent 4d974e5 commit 20f56e6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/trl/embedding_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_current_rng_state(self):
7777
return {
7878
"cpu": [paddle.framework.core.default_cpu_generator().get_state()],
7979
"cuda": [paddle.get_rng_state()],
80-
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()],
80+
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()] if self.args.use_hybrid_parallel else []
8181
}
8282

8383
def reset_rng_state(self, states, index=0):
@@ -86,13 +86,13 @@ def reset_rng_state(self, states, index=0):
8686
raise ValueError("The length of state should be 3")
8787
cpu_state = states["cpu"][index]
8888
cuda_state = states["cuda"][index]
89-
hybrid_state = states["hybrid"][index]
9089
paddle.framework.core.default_cpu_generator().set_state(cpu_state)
9190
# TODO(daisiming): support xpu and other custom devices.
9291
if core.is_compiled_with_cuda():
9392
for j in range(core.get_cuda_device_count()):
9493
core.default_cuda_generator(j).set_state(cuda_state[j])
9594
if self.args.use_hybrid_parallel:
95+
hybrid_state = states["hybrid"][index]
9696
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state)
9797

9898
def accum_forward_backward(self, model):

0 commit comments

Comments
 (0)