Skip to content

Commit 881c3b3

Browse files
author
Me
committed
fix
1 parent 02617f6 commit 881c3b3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tf2onnx/rewriter/bilstm_rewriter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def process_bilstm(g, bi_lstms):
5555
lstm_inputs.extend([lstm_fw.input[4], h_node.output[0], c_node.output[0]])
5656

5757
direction = "bidirectional"
58-
attr = {}
58+
attr = {"direction": direction}
5959
for name in rnn_utils.onnx_rnn_attr_mapping[rnn_utils.ONNX_RNN_TYPE.LSTM]:
6060
attr_val = lstm_fw.get_attr_value(name)
6161
if attr_val:
@@ -65,7 +65,8 @@ def process_bilstm(g, bi_lstms):
6565
for act in lstm_fw.get_attr_value("activations", [])]
6666
activations += [act.decode("utf-8")
6767
for act in lstm_bw.get_attr_value("activations", [])]
68-
attr.update({"direction": direction, "activations": activations})
68+
if activations:
69+
attr["activations"] = activations
6970

7071
bi_lstm_node = g.make_node("LSTM", lstm_inputs, attr=attr, output_count=3)
7172
all_nodes.append(bi_lstm_node)

0 commit comments

Comments
 (0)