Skip to content

Commit 2be1fd4

Browse files
thatchAndreas Kodewitz
authored andcommitted
Compatibility with networkx 2.5 (snorkel-team#1645)
1 parent b9f319a commit 2be1fd4

File tree

5 files changed

+13
-15
lines changed

5 files changed

+13
-15
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ torch>=1.2.0,<2.0.0
2222
munkres>=1.0.6
2323

2424
# LF dependency learning
25-
networkx>=2.2,<2.4
25+
networkx>=2.2,<2.6
2626

2727
# Model introspection tools
2828
tensorboard>=1.14.0,<2.0.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"scikit-learn>=0.23.0,<0.25.0",
4444
"torch>=1.2.0,<2.0.0",
4545
"tensorboard>=1.14.0,<2.0.0",
46-
"networkx>=2.2,<2.4",
46+
"networkx>=2.2,<2.6",
4747
],
4848
python_requires=">=3.6",
4949
keywords="machine-learning ai weak-supervision",

snorkel/labeling/model/graph_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Gr
88
Given a set of int nodes i and edges (i,j), returns a clique tree.
99
1010
Clique tree is an object G for which:
11-
- G.node[i]['members'] contains the set of original nodes in the ith
11+
- G.nodes[i]['members'] contains the set of original nodes in the ith
1212
maximal clique
1313
- G[i][j]['members'] contains the set of original nodes in the seperator
1414
set between maximal cliques i and j
@@ -46,7 +46,7 @@ def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Gr
4646
G2.add_node(i, members=c)
4747
for i in G2.nodes:
4848
for j in G2.nodes:
49-
S = G2.node[i]["members"].intersection(G2.node[j]["members"])
49+
S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"])
5050
w = len(S)
5151
if w > 0:
5252
G2.add_edge(i, j, weight=w, members=S)

snorkel/labeling/model/label_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _get_augmented_label_matrix(
198198
[
199199
j
200200
for j in self.c_tree.nodes()
201-
if i in self.c_tree.node[j]["members"]
201+
if i in self.c_tree.nodes[j]["members"]
202202
]
203203
),
204204
)
@@ -212,7 +212,7 @@ def _get_augmented_label_matrix(
212212
L_aug = np.copy(L_ind)
213213
for item in chain(self.c_tree.nodes(), self.c_tree.edges()):
214214
if isinstance(item, int):
215-
C = self.c_tree.node[item]
215+
C = self.c_tree.nodes[item]
216216
elif isinstance(item, tuple):
217217
C = self.c_tree[item[0]][item[1]]
218218
else:

test/labeling/model/test_label_model.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,13 @@ def test_augmented_L_construction(self):
197197
self.assertEqual(L_aug[i, j * k + L_shift[i, j] - 1], 1)
198198

199199
# Finally, check the clique entries
200-
# Singleton clique 1
201-
self.assertEqual(len(lm.c_tree.node[1]["members"]), 1)
202-
j = lm.c_tree.node[1]["start_index"]
203-
self.assertEqual(L_aug[0, j], 1)
204-
205-
# Singleton clique 2
206-
self.assertEqual(len(lm.c_tree.node[2]["members"]), 1)
207-
j = lm.c_tree.node[2]["start_index"]
208-
self.assertEqual(L_aug[0, j + 1], 0)
200+
for j in range(m):
201+
node = lm.c_tree.nodes[i]
202+
self.assertEqual(len(node["members"]), 1)
203+
if 1 in node["members"]:
204+
self.assertEqual(L_aug[0, node["start_index"]], 1)
205+
if 2 in node["members"]:
206+
self.assertEqual(L_aug[0, 1 + node["start_index"]], 0)
209207

210208
def test_conditional_probs(self):
211209
L = np.array([[0, 1, 0], [0, 1, 0]])

0 commit comments

Comments
 (0)