Skip to content

Commit b5e7b62

Browse files
authored
[ARM]fix rnn last_h last_c output error (#8610)
1 parent 7ecc4cc commit b5e7b62

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

lite/kernels/arm/rnn_compute.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,10 +743,10 @@ void RnnCompute::Run() {
743743
last_h_unbind[i].Resize(dims);
744744
init_h_unbind_t.push_back(&init_h_unbind[i]);
745745
last_h_unbind_t.push_back(&last_h_unbind[i]);
746+
last_h_unbind[i].mutable_data<float>();
746747
}
747748
lite::host::math::split(
748749
pre_state[0]->data<float>(), init_h_unbind_t, 0, stride1);
749-
lite::host::math::split(state[0]->data<float>(), last_h_unbind_t, 0, stride1);
750750

751751
if ("LSTM" == mode) {
752752
for (int i = 0; i < pre_state[1]->dims()[0]; i++) {
@@ -758,11 +758,10 @@ void RnnCompute::Run() {
758758
last_c_unbind[i].Resize(dims);
759759
init_c_unbind_t.push_back(&init_c_unbind[i]);
760760
last_c_unbind_t.push_back(&last_c_unbind[i]);
761+
last_c_unbind[i].mutable_data<float>();
761762
}
762763
lite::host::math::split(
763764
pre_state[1]->data<float>(), init_c_unbind_t, 0, stride2);
764-
lite::host::math::split(
765-
state[1]->data<float>(), last_c_unbind_t, 0, stride2);
766765
}
767766

768767
std::vector<Tensor> output_vec(2);
@@ -801,6 +800,12 @@ void RnnCompute::Run() {
801800
RUN_RNN_LAYER(i, output_holder, false, 0);
802801
}
803802
}
803+
804+
lite::arm::math::concat_func<float>(last_h_unbind_t, 0, state[0]);
805+
if ("LSTM" == mode) {
806+
lite::arm::math::concat_func<float>(last_c_unbind_t, 0, state[1]);
807+
}
808+
804809
// output_holder != output
805810
if (num_layers % 2 == 0) {
806811
output->CopyDataFrom(*output_holder);

0 commit comments

Comments
 (0)