Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ torch>=1.2.0,<2.0.0
munkres>=1.0.6

# LF dependency learning
networkx>=2.2,<2.4
networkx>=2.2,<2.6

# Model introspection tools
tensorboard>=1.14.0,<2.0.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"scikit-learn>=0.20.2,<0.25.0",
"torch>=1.2.0,<2.0.0",
"tensorboard>=1.14.0,<2.0.0",
"networkx>=2.2,<2.4",
"networkx>=2.2,<2.6",
],
python_requires=">=3.6",
keywords="machine-learning ai weak-supervision",
Expand Down
4 changes: 2 additions & 2 deletions snorkel/labeling/model/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Gr
Given a set of int nodes i and edges (i,j), returns a clique tree.

Clique tree is an object G for which:
- G.node[i]['members'] contains the set of original nodes in the ith
- G.nodes[i]['members'] contains the set of original nodes in the ith
maximal clique
- G[i][j]['members'] contains the set of original nodes in the seperator
set between maximal cliques i and j
Expand Down Expand Up @@ -46,7 +46,7 @@ def get_clique_tree(nodes: Iterable[int], edges: List[Tuple[int, int]]) -> nx.Gr
G2.add_node(i, members=c)
for i in G2.nodes:
for j in G2.nodes:
S = G2.node[i]["members"].intersection(G2.node[j]["members"])
S = G2.nodes[i]["members"].intersection(G2.nodes[j]["members"])
w = len(S)
if w > 0:
G2.add_edge(i, j, weight=w, members=S)
Expand Down
4 changes: 2 additions & 2 deletions snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _get_augmented_label_matrix(
[
j
for j in self.c_tree.nodes()
if i in self.c_tree.node[j]["members"]
if i in self.c_tree.nodes[j]["members"]
]
),
)
Expand All @@ -211,7 +211,7 @@ def _get_augmented_label_matrix(
L_aug = np.copy(L_ind)
for item in chain(self.c_tree.nodes(), self.c_tree.edges()):
if isinstance(item, int):
C = self.c_tree.node[item]
C = self.c_tree.nodes[item]
elif isinstance(item, tuple):
C = self.c_tree[item[0]][item[1]]
else:
Expand Down
16 changes: 7 additions & 9 deletions test/labeling/model/test_label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,13 @@ def test_augmented_L_construction(self):
self.assertEqual(L_aug[i, j * k + L_shift[i, j] - 1], 1)

# Finally, check the clique entries
# Singleton clique 1
self.assertEqual(len(lm.c_tree.node[1]["members"]), 1)
j = lm.c_tree.node[1]["start_index"]
self.assertEqual(L_aug[0, j], 1)

# Singleton clique 2
self.assertEqual(len(lm.c_tree.node[2]["members"]), 1)
j = lm.c_tree.node[2]["start_index"]
self.assertEqual(L_aug[0, j + 1], 0)
for j in range(m):
node = lm.c_tree.nodes[i]
self.assertEqual(len(node["members"]), 1)
if 1 in node["members"]:
self.assertEqual(L_aug[0, node["start_index"]], 1)
if 2 in node["members"]:
self.assertEqual(L_aug[0, 1 + node["start_index"]], 0)

def test_conditional_probs(self):
L = np.array([[0, 1, 0], [0, 1, 0]])
Expand Down