Skip to content

Commit ebc8a1b

Browse files
committed
Fix RNN OP multi-threads predict bug (PaddlePaddle#41529)
1 parent ae34db3 commit ebc8a1b

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

paddle/phi/kernels/cpu/rnn_kernel.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -832,11 +832,13 @@ void RnnKernel(const Context& dev_ctx,
832832
DenseTensor* dropout_state,
833833
std::vector<DenseTensor*> state,
834834
DenseTensor* reserve) {
835-
if (dropout_state->IsInitialized()) {
836-
if (dropout_state->numel() != out->numel()) dropout_state->clear();
835+
if (!is_test) {
836+
if (dropout_state->IsInitialized()) {
837+
if (dropout_state->numel() != out->numel()) dropout_state->clear();
838+
}
839+
const auto& out_dim = out->dims();
840+
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
837841
}
838-
const auto& out_dim = out->dims();
839-
Full<uint8_t>(dev_ctx, {out_dim.Get(), out_dim.size()}, 1, dropout_state);
840842

841843
// init the output and allocate the memory
842844
dev_ctx.template Alloc<T>(out);

0 commit comments

Comments
 (0)