Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
],
[1, 2, 1, 0, 0, 0, 0, 0, 0],
)
causal_mask_nodes_1 = None
causal_mask_nodes_2 = None
if add_qk_nodes is not None:
add_qk = add_mask.input[1]
else:
Expand All @@ -302,6 +304,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
)

if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
Expand All @@ -320,7 +323,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
output=attention_last_node.output[0],
add_qk_str=add_qk,
scale=None,
causal=(add_mask is not None),
causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None),
)
if new_node is None:
logger.debug("fuse_attention: failed to create fused node")
Expand Down
Loading