Skip to content

Commit a31aefd

Browse files
author
EnliteAI Bot
committed
RL-2070: remove activation after last layer
(Issue RL-2070 - Implement PyG gnn block)
1 parent 80ec8d5 commit a31aefd

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

maze/perception/blocks/feed_forward/gnn_pyg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Te
6363
6464
:param x: Note feature matrix, [n_nodes, in_features] or [B, n_nodes, in_features].
6565
:param edge_index: The graph connectivity in COO format, [2, E] or [B, 2, E].
66-
:param edge_attr: The edge attributes, [E, D] or [B, E, D]. D is the dimension of edge attribute.
66+
:param edge_attr: The edge attributes, [E, D] or [B, E, D]. D is the edge attribute dimension.
6767
:return: Output tensor.
6868
"""
6969
if x.dim() == 2: # Single graph (no batch)
@@ -192,7 +192,9 @@ def build_layer_dict(self) -> OrderedDict:
192192
bias=self.bias,
193193
gnn_type=self.gnn_type,
194194
)
195-
layer_dict[f'activation_{layer_idx}_{self.non_lin.__name__}'] = self.non_lin()
195+
# Add activation function only for intermediate layers
196+
if layer_idx < len(self.hidden_features) - 1:
197+
layer_dict[f'activation_{layer_idx}_{self.non_lin.__name__}'] = self.non_lin()
196198
in_feats = out_feats
197199

198200
return layer_dict

maze/test/perception/blocks/test_gnn_pyg.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_gnn_layer_forward(gnn_type, batch_size, bias):
4545
if batch_size is None:
4646
x = torch.randn(n_nodes, in_features)
4747
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
48-
edge_attr = torch.randn(edge_index.shape[1]) # (E,) for GCN or even random
48+
edge_attr = torch.randn(edge_index.shape[1]) # (E,) for GCN
4949
else:
5050
x = torch.randn(batch_size, n_nodes, in_features)
5151
edge_index = []
@@ -70,7 +70,7 @@ def test_gnn_layer_forward(gnn_type, batch_size, bias):
7070
@pytest.mark.parametrize("gnn_type", SUPPORTED_GNNS)
7171
def test_gnn_block_forward(gnn_type: str):
7272
"""
73-
Test the forward pass of GNNBlockPyG with dummy data.
73+
Test the forward pass of GNNBlockPyG.
7474
"""
7575

7676
batch_size = 2
@@ -91,7 +91,6 @@ def test_gnn_block_forward(gnn_type: str):
9191

9292
node_feats = torch.randn(batch_size, n_nodes, in_features)
9393

94-
# Edge index shape expected: (batch_size, 2, E) => E = 10 here
9594
edge_index = []
9695
edge_attr = []
9796
for _ in range(batch_size):

0 commit comments

Comments
 (0)