- """Contains a gnn block that uses Pytorch Geometric to support different types of GNNs"""
+ """ Contains a gnn block that uses Pytorch Geometric to support different types of GNNs"""
  from collections import OrderedDict
- from typing import Sequence, Callable
+ from typing import Sequence, Callable, Any

  import torch
  from torch import nn as nn

- from torch_geometric.nn import GCNConv, SAGEConv, GraphConv
+ from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, GATConv

  from maze.core.annotations import override
  from maze.core.utils.factory import Factory
@@ -30,35 +30,59 @@ def create_dummy_edge_index_tensor() -> torch.Tensor:
      return create_dummy_edge_index_tensor


- SUPPORTED_GNNS = ['gcn', 'sage', 'graph_conv']
+ SUPPORTED_GNNS = ['gcn', 'sage', 'graph_conv', 'gat']


  class GNNLayerPyG(nn.Module):
-     """Simple graph convolution layer.
+     """Simple graph neural network layer.

      :param in_features: The number of input features.
      :param out_features: The number of output features.
-     :param bias: Whether to include bias in the PyG layer.
      :param gnn_type: The type of GNN layer.
+     :param gnn_kwargs: Additional keyword arguments passed to the underlying PyG layer.
      """

-     def __init__(self, in_features: int, out_features: int, bias: bool, gnn_type: str) -> None:
+     def __init__(
+             self,
+             in_features: int,
+             out_features: int,
+             gnn_type: str,
+             gnn_kwargs: dict[str, Any] | None
+     ) -> None:
          super().__init__()

          self.in_features = in_features
          self.out_features = out_features
-         self.bias = bias
-
-         if gnn_type.lower() == 'gcn':
-             self.gnn_layer = GCNConv(in_features, out_features, bias=bias)
-         elif gnn_type.lower() == 'sage':
-             self.gnn_layer = SAGEConv(in_features, out_features, bias=bias)
-         elif gnn_type.lower() == 'graph_conv':
-             self.gnn_layer = GraphConv(in_features, out_features, bias=bias)
+         self.gnn_type = gnn_type.lower()
+         self.gnn_kwargs = gnn_kwargs if gnn_kwargs is not None else {}
+
+         if self.gnn_type == 'gcn':
+             self.gnn_layer = GCNConv(
+                 in_channels=in_features,
+                 out_channels=out_features,
+                 **self.gnn_kwargs
+             )
+         elif self.gnn_type == 'sage':
+             self.gnn_layer = SAGEConv(
+                 in_channels=in_features,
+                 out_channels=out_features,
+                 **self.gnn_kwargs
+             )
+         elif self.gnn_type == 'graph_conv':
+             self.gnn_layer = GraphConv(
+                 in_channels=in_features,
+                 out_channels=out_features,
+                 **self.gnn_kwargs
+             )
+         elif self.gnn_type == 'gat':
+             # For GAT with edge attributes, set "edge_dim" in gnn_kwargs.
+             self.gnn_layer = GATConv(
+                 in_channels=in_features,
+                 out_channels=out_features,
+                 **self.gnn_kwargs
+             )
          else:
              raise ValueError(f'Unsupported GNN type: {gnn_type}. Supported GNNs are {SUPPORTED_GNNS}')

-         self.gnn_type = gnn_type
-
      @override(nn.Module)
      def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
          """
@@ -69,39 +93,66 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Te
          :param edge_attr: The edge attributes, [E, D] or [B, E, D]. D is the edge attribute dimension.
          :return: Output tensor.
          """
-         if x.dim() == 2:  # Single graph (no batch)
-             if self.gnn_type in ['gcn', 'graph_conv']:
-                 return self.gnn_layer(x, edge_index, edge_weight=edge_attr if edge_attr.dim() == 1 else None)
-             else:
-                 # SAGEConv does not use edge_attr
-                 return self.gnn_layer(x, edge_index)
-
-         elif x.dim() == 3:  # Batched graphs
-             batch_size, num_nodes, in_features = x.shape
-
-             # Flatten batch for efficient processing
-             x_flat = x.view(-1, in_features)  # [B*N, in_features]
-             edge_index_flat = torch.cat([edge_index[b] + b * num_nodes for b in range(batch_size)], dim=1)
-             edge_attr_flat = torch.cat([edge_attr[b] for b in range(batch_size)],
-                                        dim=0) if edge_attr is not None else None
-
-             if self.gnn_type in ['gcn', 'graph_conv']:
-                 out = self.gnn_layer(
-                     x_flat, edge_index_flat,
-                     edge_weight=edge_attr_flat if (edge_attr_flat is not None and edge_attr_flat.dim() == 1) else None
-                 )
-             else:
-                 out = self.gnn_layer(x_flat, edge_index_flat)
-
-             return out.view(batch_size, num_nodes, -1)  # Reshape back

+         reshaped = False
+         if x.dim() == 2:
+             x = x.unsqueeze(0)
+             edge_index = edge_index.unsqueeze(0)
+             edge_attr = edge_attr.unsqueeze(0)
+             reshaped = True
+
+         # Expect x with shape [B, n_nodes, in_features].
+         batch_size, n_nodes, in_features = x.shape
+
+         # Flatten the batch for node features: (B*n_nodes, in_features)
+         x_flat = x.view(-1, in_features)
+
+         # Flatten the batch for edges
+         edge_index_list = []
+         edge_attr_list = []
+         for b in range(batch_size):
+             offset = b * n_nodes
+             edge_index_batch = edge_index[b] + offset  # (2, E)
+             edge_index_list.append(edge_index_batch)
+             if edge_attr is not None:
+                 edge_attr_list.append(edge_attr[b])  # (E, D) or (E,)
+
+         edge_index_flat = torch.cat(edge_index_list, dim=1)  # => (2, sum(E_b))
+         edge_attr_flat = torch.cat(edge_attr_list, dim=0) if edge_attr_list else None
+
+         if self.gnn_type in ['gcn', 'graph_conv']:
+             # Interpret 1D edge_attr as edge_weight if provided
+             edge_weight = None
+             if edge_attr_flat is not None and edge_attr_flat.dim() == 1:
+                 edge_weight = edge_attr_flat
+
+             out_flat = self.gnn_layer(x_flat, edge_index_flat, edge_weight=edge_weight)
+
+         elif self.gnn_type == 'sage':
+             out_flat = self.gnn_layer(x_flat, edge_index_flat)
+
+         elif self.gnn_type == 'gat':
+             # GATConv can use edge_attr if "edge_dim" is provided in gnn_kwargs
+             if self.gnn_layer.edge_dim is not None:
+                 assert edge_attr_flat.shape[-1] == self.gnn_layer.edge_dim, \
+                     f'The edge feature size: {edge_attr_flat.shape[-1]} must match the edge_dim: {self.gnn_layer.edge_dim}'
+             out_flat = self.gnn_layer(x_flat, edge_index_flat, edge_attr=edge_attr_flat)
          else:
-             raise ValueError(f"Unexpected x shape: {x.shape}, expected 2D or 3D.")
+             raise ValueError(f"Unsupported GNN type: {self.gnn_type}")
+
+         # Reshape back to (B, n_nodes, out_features)
+         out = out_flat.view(batch_size, n_nodes, -1)
+
+         # Reshape back if we had inserted a batch dimension of size 1
+         if reshaped:
+             out = out.squeeze(0)  # => (n_nodes, out_features)
+
+         return out

      @override(nn.Module)
      def __repr__(self):
          txt = f'{self.gnn_type}: ({self.in_features} -> {self.out_features})'
-         txt += ' (with bias)' if self.bias else ' (without bias)'
+         txt += f', kwargs={self.gnn_kwargs}'
          return txt
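Note (reviewer aside, not part of the diff): a minimal sketch of how the new gnn_kwargs path could be exercised with a GAT layer. The import path, shapes, and values below are illustrative assumptions, not taken from this PR.

    import torch
    # Hypothetical import path; adjust to wherever GNNLayerPyG actually lives in the repo.
    from maze.perception.blocks.gnn_pyg import GNNLayerPyG

    layer = GNNLayerPyG(in_features=16, out_features=32, gnn_type='gat',
                        gnn_kwargs={'heads': 4, 'edge_dim': 8})
    x = torch.randn(10, 16)                     # 10 nodes with 16 features each
    edge_index = torch.randint(0, 10, (2, 20))  # 20 random edges
    edge_attr = torch.randn(20, 8)              # edge features matching edge_dim=8
    out = layer(x, edge_index, edge_attr)
    print(out.shape)                            # torch.Size([10, 128]): concat=True by default, so 32 * 4 heads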
@@ -113,41 +164,44 @@ class GNNBlockPyG(ShapeNormalizationBlock):
      :param in_shapes: List of input shapes.
      :param hidden_features: List containing the number of hidden features for hidden layers.
      :param non_lin: The non-linearity to apply after each layer.
-     :param bias: Whether to include bias in the GNN layers.
-     :param gnn_type: The type of GNN layer.
+     :param gnn_type: The type of GNN layer.
+     :param gnn_kwargs: Extra kwargs to pass to the GNN layers.
      """

      def __init__(
              self, in_keys: str | list[str], out_keys: str | list[str],
              in_shapes: Sequence[int] | list[Sequence[int]],
              hidden_features: list[int],
              non_lin: str | nn.Module,
-             bias: bool,
              gnn_type: str,
+             gnn_kwargs: dict[str, Any] | None
      ):

          super().__init__(in_keys=in_keys, out_keys=out_keys, in_shapes=in_shapes, in_num_dims=[3] * 3, out_num_dims=3)

          self.gnn_type = gnn_type
-         self.bias = bias
+         self.gnn_kwargs = gnn_kwargs if gnn_kwargs is not None else {}

          assert len(self.in_keys) == 3, \
-             'There should be three input keys: node feature matrix, graph edge_index, and edge attributes.'
+             f"Expected three input keys, got {len(self.in_keys)}: {self.in_keys}"

          # Specify dummy dict creation function for edge_index:
          self.dummy_dict_creators[1] = _dummy_edge_index_factory(self.in_shapes[1], self.in_shapes[0][0])

-         # Init class objects
          self.input_features = self.in_shapes[0][-1]
+
+         if (gnn_type == 'gat' and gnn_kwargs is not None and 'heads' in gnn_kwargs
+                 and ('concat' not in gnn_kwargs or gnn_kwargs['concat'])):
+             self.output_features = hidden_features[-1] * gnn_kwargs['heads']
+         else:
+             self.output_features = hidden_features[-1]
+
          self.hidden_features = hidden_features
-         self.output_features = self.hidden_features[-1]

          self.non_lin: type[nn.Module] = Factory(base_type=nn.Module).type_from_name(non_lin)

-         # Form layers dictionary
+         # Create the GNN layers
          layer_dict = self.build_layer_dict()
-
-         # Compile network
          self.net = nn.Sequential(layer_dict)

      @override(ShapeNormalizationBlock)
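Note (reviewer aside, not part of the diff): a worked example of the GAT output-size bookkeeping introduced above, with illustrative numbers.

    hidden_features = [32, 32]
    gnn_kwargs = {'heads': 4}          # GATConv concatenates heads by default (concat=True)
    output_features = hidden_features[-1] * gnn_kwargs['heads']
    assert output_features == 128      # each of the 4 heads contributes 32 features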
@@ -163,11 +217,12 @@ def normalized_forward(self, block_input: dict[str, torch.Tensor]) -> dict[str,
          assert edge_attr.ndim == self.in_num_dims[2]

          assert node_feat.shape[-1] == self.input_features, \
-             f"Feature dimension should fit: {node_feat.shape[-1]} vs {self.input_features}"
+             f"Mismatch in node feature dimension: {node_feat.shape[-1]} vs expected {self.input_features}"
          assert edge_index.shape[-1] == edge_attr.shape[-2], \
-             "Number of edges must be consistent"
+             (f"Number of edges (E) must be consistent between edge_index: {edge_index.shape[-1]} "
+              f"and edge_attr: {edge_attr.shape[-2]}")

-         # forward pass
+         # Forward pass
          x = node_feat
          for layer in self.net:
              if isinstance(layer, GNNLayerPyG):
@@ -185,32 +240,37 @@ def normalized_forward(self, block_input: dict[str, torch.Tensor]) -> dict[str,
      def build_layer_dict(self) -> OrderedDict:
          """Compiles a block-specific dictionary of network layers.
-
-         This could be overwritten by derived layers (e.g. to get a 'BatchNormalizedConvolutionBlock').
-
-         :return: Ordered dictionary of torch modules [str, nn.Module].
+         :return: Ordered dictionary of torch modules
          """
          layer_dict = OrderedDict()
          in_feats = self.input_features

          for layer_idx, out_feats in enumerate(self.hidden_features):
+
              layer_dict[f'{self.gnn_type}_{layer_idx}'] = GNNLayerPyG(
                  in_features=in_feats,
                  out_features=out_feats,
-                 bias=self.bias,
                  gnn_type=self.gnn_type,
+                 gnn_kwargs=self.gnn_kwargs
              )
-             # Add activation function only for intermediate layers
+             # Insert activation function after each hidden layer except the last
              if layer_idx < len(self.hidden_features) - 1:
-                 layer_dict[f'activation_{layer_idx}_{self.non_lin.__name__}'] = self.non_lin()
-             in_feats = out_feats
+                 layer_name = f'activation_{layer_idx}_{self.non_lin.__name__}'
+                 layer_dict[layer_name] = self.non_lin()
+
+             if self.gnn_type == 'gat' and 'heads' in self.gnn_kwargs and \
+                     ('concat' not in self.gnn_kwargs or self.gnn_kwargs['concat']):
+                 in_feats = out_feats * self.gnn_kwargs['heads']
+             else:
+                 in_feats = out_feats

          return layer_dict

      def __repr__(self):
-         txt = f'{self.__class__.__name__}'
-         txt += f'({self.non_lin.__name__})'
-         txt += '\n\t' + f'({self.input_features}->' + '->'.join([f'{h}' for h in self.hidden_features]) + ')'
-         txt += f'\n\tBias: {self.bias}'
-         txt += f'\n\tOut Shapes: {self.out_shapes()}'
-         return txt
+         txt = (
+             f"{self.__class__.__name__}({self.non_lin.__name__})\n"
+             f"\t({self.input_features}->" + "->".join([f"{h}" for h in self.hidden_features]) + ")\n"
+         )
+         txt += f"\n\tGNN kwargs: {self.gnn_kwargs}"
+         txt += f"\n\tOut Shapes: {self.out_shapes()}"
+         return txt
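Note (reviewer aside, not part of the diff): a rough sketch of constructing the block with the new signature. The in_keys/out_keys names, the shapes, and the non_lin string are illustrative assumptions, not values from this PR.

    block = GNNBlockPyG(
        in_keys=['node_features', 'edge_index', 'edge_attr'],
        out_keys='gnn_out',
        in_shapes=[(10, 16), (2, 20), (20, 8)],   # [N, F_in], [2, E], [E, D]
        hidden_features=[32, 32],
        non_lin='torch.nn.ReLU',                  # resolved via Factory.type_from_name
        gnn_type='gat',
        gnn_kwargs={'heads': 4, 'edge_dim': 8},
    )
    # With heads=4 and the default concat=True, block.output_features would be 128 here.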