
Commit fb5d721

Author: EnliteAI Bot (committed)
RL-2080: Add GAT to pyg gnn block
(Issue RL-2080 - Add GraphConv and GAT to the pyg gnn block)
1 parent 41a5b4e commit fb5d721
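
In short, the commit routes all layer-specific options (including the former bias flag) through a new gnn_kwargs dict and adds 'gat' to the supported GNN types. A hypothetical configuration of the extended block could look as follows; the observation keys, shapes and the non-linearity string are illustrative assumptions, not taken from the commit:

# Hypothetical sketch: configuring the new GAT option of GNNBlockPyG.
gnn_block = GNNBlockPyG(
    in_keys=['node_features', 'edge_index', 'edge_attr'],    # assumed observation keys
    out_keys='gnn_out',
    in_shapes=[(10, 16), (2, 30), (30, 4)],                  # assumed: 10 nodes x 16 features, 30 edges x 4 edge features
    hidden_features=[32, 32],
    non_lin='torch.nn.ReLU',                                 # assumed non-linearity name
    gnn_type='gat',
    gnn_kwargs={'heads': 4, 'concat': True, 'edge_dim': 4, 'bias': True},  # forwarded verbatim to every GATConv
)
# With concat=True and 4 heads, the block reports output_features = 32 * 4 = 128.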


2 files changed: +238 −74 lines changed

Lines changed: 131 additions & 71 deletions
@@ -1,11 +1,11 @@
-"""Contains a gnn block that uses Pytorch Geometric to support different types of GNNs"""
+""" Contains a gnn block that uses Pytorch Geometric to support different types of GNNs"""
 from collections import OrderedDict
-from typing import Sequence, Callable
+from typing import Sequence, Callable, Any

 import torch
 from torch import nn as nn

-from torch_geometric.nn import GCNConv, SAGEConv, GraphConv
+from torch_geometric.nn import GCNConv, SAGEConv, GraphConv, GATConv

 from maze.core.annotations import override
 from maze.core.utils.factory import Factory
@@ -30,35 +30,59 @@ def create_dummy_edge_index_tensor() -> torch.Tensor:
     return create_dummy_edge_index_tensor


-SUPPORTED_GNNS = ['gcn', 'sage', 'graph_conv']
+SUPPORTED_GNNS = ['gcn', 'sage', 'graph_conv', 'gat']

 class GNNLayerPyG(nn.Module):
-    """Simple graph convolution layer.
+    """Simple graph neural network layer.

     :param in_features: The number of input features.
     :param out_features: The number of output features.
-    :param bias: Whether to include bias in the PyG layer.
     :param gnn_type: The type of GNN layer.
+    :param gnn_kwargs: Additional keyword arguments passed to the underlying PyG layer.
     """

-    def __init__(self, in_features: int, out_features: int, bias: bool, gnn_type: str) -> None:
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        gnn_type: str,
+        gnn_kwargs: dict[str, Any] | None
+    ) -> None:
         super().__init__()

         self.in_features = in_features
         self.out_features = out_features
-        self.bias = bias
-
-        if gnn_type.lower() == 'gcn':
-            self.gnn_layer = GCNConv(in_features, out_features, bias=bias)
-        elif gnn_type.lower() == 'sage':
-            self.gnn_layer = SAGEConv(in_features, out_features, bias=bias)
-        elif gnn_type.lower() == 'graph_conv':
-            self.gnn_layer = GraphConv(in_features, out_features, bias=bias)
+        self.gnn_type = gnn_type.lower()
+        self.gnn_kwargs = gnn_kwargs if gnn_kwargs is not None else {}
+
+        if self.gnn_type == 'gcn':
+            self.gnn_layer = GCNConv(
+                in_channels=in_features,
+                out_channels=out_features,
+                **self.gnn_kwargs
+            )
+        elif self.gnn_type == 'sage':
+            self.gnn_layer = SAGEConv(
+                in_channels=in_features,
+                out_channels=out_features,
+                **self.gnn_kwargs
+            )
+        elif self.gnn_type == 'graph_conv':
+            self.gnn_layer = GraphConv(
+                in_channels=in_features,
+                out_channels=out_features,
+                **self.gnn_kwargs
+            )
+        elif self.gnn_type == 'gat':
+            # For GAT with edge attributes, set "edge_dim" in gnn_kwargs.
+            self.gnn_layer = GATConv(
+                in_channels=in_features,
+                out_channels=out_features,
+                **self.gnn_kwargs
+            )
         else:
             raise ValueError(f'Unsupported GNN type: {gnn_type}. Supported GNNs are {SUPPORTED_GNNS}')

-        self.gnn_type = gnn_type
-
     @override(nn.Module)
     def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
         """
@@ -69,39 +93,66 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
         :param edge_attr: The edge attributes, [E, D] or [B, E, D]. D is the edge attribute dimension.
         :return: Output tensor.
         """
-        if x.dim() == 2:  # Single graph (no batch)
-            if self.gnn_type in ['gcn', 'graph_conv']:
-                return self.gnn_layer(x, edge_index, edge_weight=edge_attr if edge_attr.dim() == 1 else None)
-            else:
-                # SAGEConv does not use edge_attr
-                return self.gnn_layer(x, edge_index)
-
-        elif x.dim() == 3:  # Batched graphs
-            batch_size, num_nodes, in_features = x.shape
-
-            # Flatten batch for efficient processing
-            x_flat = x.view(-1, in_features)  # [B*N, in_features]
-            edge_index_flat = torch.cat([edge_index[b] + b * num_nodes for b in range(batch_size)], dim=1)
-            edge_attr_flat = torch.cat([edge_attr[b] for b in range(batch_size)],
-                                       dim=0) if edge_attr is not None else None
-
-            if self.gnn_type in ['gcn', 'graph_conv']:
-                out = self.gnn_layer(
-                    x_flat, edge_index_flat,
-                    edge_weight=edge_attr_flat if (edge_attr_flat is not None and edge_attr_flat.dim() == 1) else None
-                )
-            else:
-                out = self.gnn_layer(x_flat, edge_index_flat)
-
-            return out.view(batch_size, num_nodes, -1)  # Reshape back

+        reshaped = False
+        if x.dim() == 2:
+            x = x.unsqueeze(0)
+            edge_index = edge_index.unsqueeze(0)
+            edge_attr = edge_attr.unsqueeze(0)
+            reshaped = True
+
+        # Expect x with shape [B, n_nodes, in_features].
+        batch_size, n_nodes, in_features = x.shape
+
+        # Flatten the batch for node features: (B*n_nodes, in_features)
+        x_flat = x.view(-1, in_features)
+
+        # Flatten the batch for edges
+        edge_index_list = []
+        edge_attr_list = []
+        for b in range(batch_size):
+            offset = b * n_nodes
+            edge_index_batch = edge_index[b] + offset  # (2, E)
+            edge_index_list.append(edge_index_batch)
+            if edge_attr is not None:
+                edge_attr_list.append(edge_attr[b])  # (E, D) or (E,)
+
+        edge_index_flat = torch.cat(edge_index_list, dim=1)  # => (2, sum(E_b))
+        edge_attr_flat = torch.cat(edge_attr_list, dim=0) if edge_attr_list else None
+
+        if self.gnn_type in ['gcn', 'graph_conv']:
+            # Interpret 1D edge_attr as edge_weight if provided
+            edge_weight = None
+            if edge_attr_flat is not None and edge_attr_flat.dim() == 1:
+                edge_weight = edge_attr_flat
+
+            out_flat = self.gnn_layer(x_flat, edge_index_flat, edge_weight=edge_weight)
+
+        elif self.gnn_type == 'sage':
+            out_flat = self.gnn_layer(x_flat, edge_index_flat)
+
+        elif self.gnn_type == 'gat':
+            # GATConv can use edge_attr if "edge_dim" is provided in gnn_kwargs
+            if self.gnn_layer.edge_dim is not None:
+                assert edge_attr_flat.shape[-1] == self.gnn_layer.edge_dim, \
+                    f'The edge feature size: {edge_attr_flat.shape[-1]} must match the edge_dim: {self.gnn_layer.edge_dim}'
+            out_flat = self.gnn_layer(x_flat, edge_index_flat, edge_attr=edge_attr_flat)
         else:
-            raise ValueError(f"Unexpected x shape: {x.shape}, expected 2D or 3D.")
+            raise ValueError(f"Unsupported GNN type: {self.gnn_type}")
+
+        # Reshape back to (B, n_nodes, out_features)
+        out = out_flat.view(batch_size, n_nodes, -1)
+
+        # Reshape back if we had inserted a batch dimension of size 1
+        if reshaped:
+            out = out.squeeze(0)  # => (n_nodes, out_features)
+
+        return out

     @override(nn.Module)
     def __repr__(self):
         txt = f'{self.gnn_type}: ({self.in_features} -> {self.out_features})'
-        txt += ' (with bias)' if self.bias else ' (without bias)'
+        txt += f', kwargs={self.gnn_kwargs}'
         return txt

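The rewritten forward() above always works on a batched view: a single graph gets a batch dimension of one, and a batch of same-sized graphs is merged into one disconnected super-graph by offsetting each graph's node indices by b * n_nodes. A minimal plain-PyTorch sketch of that offsetting step:

import torch

n_nodes = 3                                    # every graph in the batch has 3 nodes
edge_index = torch.tensor([[[0, 1], [1, 2]],   # graph 0: edges 0->1 and 1->2
                           [[0, 2], [1, 0]]])  # graph 1: edges 0->1 and 2->0
flat = torch.cat([edge_index[b] + b * n_nodes for b in range(edge_index.shape[0])], dim=1)
print(flat)
# tensor([[0, 1, 3, 5],
#         [1, 2, 4, 3]])  -> graph 1 now lives on nodes 3..5 of the flattened graph
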
@@ -113,41 +164,44 @@ class GNNBlockPyG(ShapeNormalizationBlock):
     :param in_shapes: List of input shapes.
     :param hidden_features: List containing the number of hidden features for hidden layers.
     :param non_lin: The non-linearity to apply after each layer.
-    :param bias: Whether to include bias in the GNN layers.
-    :param gnn_type: The type of GNN layer.
+    :param gnn_type: The type of GNN layer
+    :param gnn_kwargs: Extra kwargs to pass to the GNN layers.
     """

     def __init__(
            self, in_keys: str | list[str], out_keys: str | list[str],
            in_shapes: Sequence[int] | list[Sequence[int]],
            hidden_features: list[int],
            non_lin: str | nn.Module,
-           bias: bool,
            gnn_type: str,
+           gnn_kwargs: dict[str, Any] | None
     ):

         super().__init__(in_keys=in_keys, out_keys=out_keys, in_shapes=in_shapes, in_num_dims=[3]*3, out_num_dims=3)

         self.gnn_type = gnn_type
-        self.bias = bias
+        self.gnn_kwargs = gnn_kwargs if gnn_kwargs is not None else {}

         assert len(self.in_keys) == 3, \
-            'There should be three input keys: node feature matrix, graph edge_index, and edge attributes.'
+            f"Expected three input keys, got {len(self.in_keys)}: {self.in_keys}"

         # Specify dummy dict creation function for edge_index:
         self.dummy_dict_creators[1] = _dummy_edge_index_factory(self.in_shapes[1], self.in_shapes[0][0])

-        # Init class objects
         self.input_features = self.in_shapes[0][-1]
+
+        if (gnn_type == 'gat' and gnn_kwargs is not None and 'heads' in gnn_kwargs
+                and ('concat' not in gnn_kwargs or gnn_kwargs['concat'])):
+            self.output_features = hidden_features[-1] * gnn_kwargs['heads']
+        else:
+            self.output_features = hidden_features[-1]
+
         self.hidden_features = hidden_features
-        self.output_features = self.hidden_features[-1]

         self.non_lin: type[nn.Module] = Factory(base_type=nn.Module).type_from_name(non_lin)

-        # Form layers dictionary
+        # Create the GNN layers
         layer_dict = self.build_layer_dict()
-
-        # Compile network
         self.net = nn.Sequential(layer_dict)

     @override(ShapeNormalizationBlock)
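
The new output_features computation accounts for GATConv concatenating its attention heads by default. A worked example of the rule with illustrative values:

hidden_features = [32, 64]
gnn_kwargs = {'heads': 4}                   # 'concat' not set -> GATConv default (True)
output_features = hidden_features[-1] * 4   # 64 * 4 = 256 concatenated head outputs

gnn_kwargs = {'heads': 4, 'concat': False}  # heads are averaged instead
output_features = hidden_features[-1]       # stays 64
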
@@ -163,11 +217,12 @@ def normalized_forward(self, block_input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         assert edge_attr.ndim == self.in_num_dims[2]

         assert node_feat.shape[-1] == self.input_features, \
-            f"Feature dimension should fit: {node_feat.shape[-1]} vs {self.input_features}"
+            f"Mismatch in node feature dimension: {node_feat.shape[-1]} vs expected {self.input_features}"
         assert edge_index.shape[-1] == edge_attr.shape[-2], \
-            "Number of edges must be consistent"
+            (f"Number of edges (E) must be consistent between edge_index: {edge_index.shape[-1]} "
+             f"and edge_attr: {edge_attr.shape[-2]}")

-        # forward pass
+        # Forward pass
         x = node_feat
         for layer in self.net:
             if isinstance(layer, GNNLayerPyG):
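
For reference, a small sketch of the batched shapes these tightened assertions refer to (sizes are illustrative); E, the number of edges, must agree between edge_index and edge_attr:

import torch

node_feat = torch.randn(8, 10, 16)                     # [B, n_nodes, input_features]
edge_index = torch.zeros(8, 2, 30, dtype=torch.long)   # [B, 2, E]
edge_attr = torch.randn(8, 30, 4)                      # [B, E, D]
assert edge_index.shape[-1] == edge_attr.shape[-2]     # E must match (30 == 30)
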
@@ -185,32 +240,37 @@ def normalized_forward(self, block_input: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

     def build_layer_dict(self) -> OrderedDict:
         """Compiles a block-specific dictionary of network layers.
-
-        This could be overwritten by derived layers (e.g. to get a 'BatchNormalizedConvolutionBlock').
-
-        :return: Ordered dictionary of torch modules [str, nn.Module].
+        :return: Ordered dictionary of torch modules
         """
         layer_dict = OrderedDict()
         in_feats = self.input_features

         for layer_idx, out_feats in enumerate(self.hidden_features):
+
             layer_dict[f'{self.gnn_type}_{layer_idx}'] = GNNLayerPyG(
                 in_features=in_feats,
                 out_features=out_feats,
-                bias=self.bias,
                 gnn_type=self.gnn_type,
+                gnn_kwargs=self.gnn_kwargs
             )
-            # Add activation function only for intermediate layers
+            # Insert activation function after each hidden layer except the last
             if layer_idx < len(self.hidden_features) - 1:
-                layer_dict[f'activation_{layer_idx}_{self.non_lin.__name__}'] = self.non_lin()
-            in_feats = out_feats
+                layer_name = f'activation_{layer_idx}_{self.non_lin.__name__}'
+                layer_dict[layer_name] = self.non_lin()
+
+            if self.gnn_type == 'gat' and 'heads' in self.gnn_kwargs and \
+                    ('concat' not in self.gnn_kwargs or self.gnn_kwargs['concat']):
+                in_feats = out_feats * self.gnn_kwargs['heads']
+            else:
+                in_feats = out_feats

         return layer_dict

     def __repr__(self):
-        txt = f'{self.__class__.__name__}'
-        txt += f'({self.non_lin.__name__})'
-        txt += '\n\t' + f'({self.input_features}->' + '->'.join([f'{h}' for h in self.hidden_features]) + ')'
-        txt += f'\n\tBias: {self.bias}'
-        txt += f'\n\tOut Shapes: {self.out_shapes()}'
-        return txt
+        txt = (
+            f"{self.__class__.__name__}({self.non_lin.__name__})\n"
+            f"\t({self.input_features}->" + "->".join([f"{h}" for h in self.hidden_features]) + ")\n"
+        )
+        txt += f"\n\tGNN kwargs: {self.gnn_kwargs}"
+        txt += f"\n\tOut Shapes: {self.out_shapes()}"
+        return txt
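
Because a GAT layer with concatenated heads widens its per-node output, build_layer_dict() now also scales in_feats between layers. A small sketch of how the widths chain, mirroring the in_feats update at the end of the loop (the input width of 16 is illustrative):

hidden_features = [32, 32]
gnn_kwargs = {'heads': 4}          # 'concat' not set -> treated as concatenating
in_feats = 16
for layer_idx, out_feats in enumerate(hidden_features):
    print(f'gat_{layer_idx}: {in_feats} -> {out_feats}')
    in_feats = out_feats * gnn_kwargs['heads']
# gat_0: 16 -> 32
# gat_1: 128 -> 32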

0 commit comments