@@ -2680,3 +2680,76 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
2680
2680
encoding3d_2 = model_3 (coord , node_type )
2681
2681
assert encoding3d_1 .shape == (bsz , max_num_nodes , max_num_nodes , num_heads )
2682
2682
assert encoding3d_2 .shape == (bsz , max_num_nodes , max_num_nodes , num_heads )
2683
+
2684
+
2685
+ @pytest .mark .parametrize ("residual" , [True , False ])
2686
+ def test_conv_with_zero_nodes_bugfix_7894 (residual ):
2687
+ """Test for PR #7894 in DGL where HeteroGraphConv with zero nodes in a
2688
+ specific node type would cause an error due to empty tensors.
2689
+ This test ensures that GATConv, GATv2Conv, and EdgeGATConv can handle
2690
+ such cases without raising errors.
2691
+ """
2692
+ # Create a heterogeneous graph with zero nodes in the "tag" type
2693
+ user_item_src = torch .tensor ([0 , 1 , 2 ])
2694
+ user_item_dst = torch .tensor ([4 , 5 , 6 ])
2695
+
2696
+ user_tag_src = torch .tensor ([], dtype = torch .int64 )
2697
+ user_tag_dst = torch .tensor ([], dtype = torch .int64 )
2698
+
2699
+ num_nodes_dict = {
2700
+ "user" : 5 ,
2701
+ "item" : 10 ,
2702
+ "tag" : 0 ,
2703
+ }
2704
+
2705
+ data_dict = {
2706
+ ("user" , "buys" , "item" ): (user_item_src , user_item_dst ),
2707
+ ("user" , "likes" , "tag" ): (user_tag_src , user_tag_dst ),
2708
+ }
2709
+
2710
+ g = dgl .heterograph (data_dict , num_nodes_dict = num_nodes_dict )
2711
+
2712
+ feat_dim = 16
2713
+ node_features = {
2714
+ "user" : torch .randn (num_nodes_dict ["user" ], feat_dim ),
2715
+ "item" : torch .randn (num_nodes_dict ["item" ], feat_dim ),
2716
+ "tag" : torch .randn (num_nodes_dict ["tag" ], feat_dim ),
2717
+ }
2718
+ edge_features = {
2719
+ ("user" , "buys" , "item" ): torch .randn (g .num_edges (("user" , "buys" , "item" )), feat_dim ),
2720
+ ("user" , "likes" , "tag" ): torch .randn (g .num_edges (("user" , "likes" , "tag" )), feat_dim ),
2721
+ }
2722
+
2723
+ # Test GATConv with zero nodes in "tag" type
2724
+ conv = nn .HeteroGraphConv ({
2725
+ ("user" , "buys" , "item" ): nn .GATConv (16 , 2 , num_heads = 2 , residual = residual ),
2726
+ ("user" , "likes" , "tag" ): nn .GATConv (16 , 2 , num_heads = 2 , residual = residual ),
2727
+ }, aggregate = "sum" )
2728
+ out = conv (g , node_features )
2729
+ assert out ["item" ].shape == (10 , 2 , 2 )
2730
+ assert out ["tag" ].shape == (0 , 2 , 2 )
2731
+ assert "user" not in out
2732
+
2733
+ # Test GATv2Conv with zero nodes in "tag" type
2734
+ conv_v2 = nn .HeteroGraphConv ({
2735
+ ("user" , "buys" , "item" ): nn .GATv2Conv (16 , 2 , num_heads = 2 , residual = residual ),
2736
+ ("user" , "likes" , "tag" ): nn .GATv2Conv (16 , 2 , num_heads = 2 , residual = residual ),
2737
+ }, aggregate = "sum" )
2738
+ out_v2 = conv_v2 (g , node_features )
2739
+ assert out_v2 ["item" ].shape == (10 , 2 , 2 )
2740
+ assert out_v2 ["tag" ].shape == (0 , 2 , 2 )
2741
+ assert "user" not in out_v2
2742
+
2743
+ # Test EdgeGATConv with zero nodes in "tag" type
2744
+ edge_conv = nn .HeteroGraphConv ({
2745
+ ("user" , "buys" , "item" ): nn .EdgeGATConv (16 , 16 , 2 , num_heads = 2 , residual = residual ),
2746
+ ("user" , "likes" , "tag" ): nn .EdgeGATConv (16 , 16 , 2 , num_heads = 2 , residual = residual ),
2747
+ }, aggregate = "sum" )
2748
+ mod_kwargs = {
2749
+ "buys" : {"edge_feat" : edge_features [("user" , "buys" , "item" )]},
2750
+ "likes" : {"edge_feat" : edge_features [("user" , "likes" , "tag" )]},
2751
+ }
2752
+ out_edge = edge_conv (g , node_features , mod_kwargs = mod_kwargs )
2753
+ assert out_edge ["item" ].shape == (10 , 2 , 2 )
2754
+ assert out_edge ["tag" ].shape == (0 , 2 , 2 )
2755
+ assert "user" not in out_edge
0 commit comments