Skip to content

Commit e4dcf0b

Browse files
authored
Fix RNN OP multi-threads predict bug (#41529) (#41560)
1 parent b810961 commit e4dcf0b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

paddle/phi/kernels/cpu/rnn_kernel.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx,
832832
DenseTensor* dropout_state,
833833
std::vector<DenseTensor*> state,
834834
DenseTensor* reserve) {
835-
if (dropout_state->IsInitialized()) {
836-
if (dropout_state->numel() != out->numel()) dropout_state->clear();
835+
if (!is_test) {
836+
if (dropout_state->IsInitialized()) {
837+
if (dropout_state->numel() != out->numel()) dropout_state->clear();
838+
}
839+
const auto& out_dim = out->dims();
840+
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
837841
}
838-
const auto& out_dim = out->dims();
839-
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
840842

841843
// init the output and allocate the memory
842844
dev_ctx.template Alloc<T>(out);

0 commit comments

Comments
 (0)