Skip to content

Commit 3d16000

Browse files
authored
[Bugfix][nn/PyTorch]: add checks to avoid view/reshape (0, -1, *) on empty tensors (#7894)
1 parent 743e65f commit 3d16000

File tree

4 files changed

+88
-12
lines changed

4 files changed

+88
-12
lines changed

python/dgl/nn/pytorch/conv/edgegatconv.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,11 @@ def forward(self, graph, feat, edge_feat, get_attention=False):
368368
# Residual.
369369
if self.res_fc is not None:
370370
# Use -1 rather than self._num_heads to handle broadcasting.
371-
resval = self.res_fc(h_dst).view(
372-
*dst_prefix_shape, -1, self._out_feats
373-
)
374-
rst = rst + resval
371+
if h_dst.numel() != 0:
372+
resval = self.res_fc(h_dst).view(
373+
*dst_prefix_shape, -1, self._out_feats
374+
)
375+
rst = rst + resval
375376
# Bias.
376377
if self.bias is not None:
377378
rst = rst + self.bias.view(

python/dgl/nn/pytorch/conv/gatconv.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,11 @@ def forward(self, graph, feat, edge_weight=None, get_attention=False):
348348
# residual
349349
if self.res_fc is not None:
350350
# Use -1 rather than self._num_heads to handle broadcasting
351-
resval = self.res_fc(h_dst).view(
352-
*dst_prefix_shape, -1, self._out_feats
353-
)
354-
rst = rst + resval
351+
if h_dst.numel() != 0:
352+
resval = self.res_fc(h_dst).view(
353+
*dst_prefix_shape, -1, self._out_feats
354+
)
355+
rst = rst + resval
355356
# bias
356357
if self.has_explicit_bias:
357358
rst = rst + self.bias.view(

python/dgl/nn/pytorch/conv/gatv2conv.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,11 @@ def forward(self, graph, feat, get_attention=False):
320320
rst = graph.dstdata["ft"]
321321
# residual
322322
if self.res_fc is not None:
323-
resval = self.res_fc(h_dst).view(
324-
h_dst.shape[0], -1, self._out_feats
325-
)
326-
rst = rst + resval
323+
if h_dst.numel() != 0:
324+
resval = self.res_fc(h_dst).view(
325+
h_dst.shape[0], -1, self._out_feats
326+
)
327+
rst = rst + resval
327328
# activation
328329
if self.activation:
329330
rst = self.activation(rst)

tests/python/pytorch/nn/test_nn.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,3 +2680,76 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
26802680
encoding3d_2 = model_3(coord, node_type)
26812681
assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
26822682
assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
2683+
2684+
2685+
@pytest.mark.parametrize("residual", [True, False])
def test_conv_with_zero_nodes_bugfix_7894(residual):
    """Regression test for PR #7894: attention convs on an empty node type.

    Builds a heterograph whose "tag" node type has zero nodes (so the
    "likes" relation has zero edges) and runs GATConv, GATv2Conv and
    EdgeGATConv through HeteroGraphConv, with and without the residual
    branch enabled. Every run must finish without raising and return
    outputs of the expected shape, including an empty output for "tag".
    """
    buys = ("user", "buys", "item")
    likes = ("user", "likes", "tag")
    relations = (buys, likes)

    feat_dim = 16
    node_counts = {"user": 5, "item": 10, "tag": 0}

    # "likes" gets empty endpoint lists because there is no "tag" node to
    # connect to.
    buys_src = torch.tensor([0, 1, 2])
    buys_dst = torch.tensor([4, 5, 6])
    likes_src = torch.tensor([], dtype=torch.int64)
    likes_dst = torch.tensor([], dtype=torch.int64)

    g = dgl.heterograph(
        {buys: (buys_src, buys_dst), likes: (likes_src, likes_dst)},
        num_nodes_dict=node_counts,
    )

    # Random inputs; only output shapes are checked below.
    node_features = {
        ntype: torch.randn(count, feat_dim)
        for ntype, count in node_counts.items()
    }
    edge_features = {
        etype: torch.randn(g.num_edges(etype), feat_dim)
        for etype in relations
    }

    def check_outputs(result):
        # Only destination types ("item", "tag") appear in the output;
        # "user" is a source type in both relations.
        assert result["item"].shape == (10, 2, 2)
        assert result["tag"].shape == (0, 2, 2)
        assert "user" not in result

    # GATConv with zero nodes in the "tag" type.
    conv = nn.HeteroGraphConv(
        {
            etype: nn.GATConv(
                feat_dim, 2, num_heads=2, residual=residual
            )
            for etype in relations
        },
        aggregate="sum",
    )
    check_outputs(conv(g, node_features))

    # GATv2Conv with zero nodes in the "tag" type.
    conv_v2 = nn.HeteroGraphConv(
        {
            etype: nn.GATv2Conv(
                feat_dim, 2, num_heads=2, residual=residual
            )
            for etype in relations
        },
        aggregate="sum",
    )
    check_outputs(conv_v2(g, node_features))

    # EdgeGATConv with zero nodes in the "tag" type; edge features are
    # routed to each relation's module through mod_kwargs.
    edge_conv = nn.HeteroGraphConv(
        {
            etype: nn.EdgeGATConv(
                feat_dim, feat_dim, 2, num_heads=2, residual=residual
            )
            for etype in relations
        },
        aggregate="sum",
    )
    mod_kwargs = {
        "buys": {"edge_feat": edge_features[buys]},
        "likes": {"edge_feat": edge_features[likes]},
    }
    check_outputs(edge_conv(g, node_features, mod_kwargs=mod_kwargs))

0 commit comments

Comments
 (0)