Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
output=attention_last_node.output[0],
add_qk_str=add_qk,
scale=None,
causal=(add_mask is not None),
causal=False,
)
if new_node is None:
logger.debug("fuse_attention: failed to create fused node")
Expand Down
Loading