Skip to content

Commit 88b2579

Browse files
authored
Fix squeeze bug for length 1 array, and inconsistency on requirements and setup.py (#1540)
1 parent e878d48 commit 88b2579

File tree

4 files changed

+8
-2
lines changed

4 files changed

+8
-2
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ tqdm>=4.33.0,<5.0.0
1919
# Internal models
2020
scikit-learn>=0.20.2,<0.22.0
2121
torch>=1.1.0,<1.2.0
22-
munkres==1.1.2
22+
munkres>=1.0.6
2323

2424
# LF dependency learning
2525
networkx>=2.2,<2.4

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
packages=find_packages(),
3636
include_package_data=True,
3737
install_requires=[
38-
"munkres==1.1.2",
38+
"munkres>=1.0.6",
3939
"numpy>=1.16.0,<2.0.0",
4040
"scipy>=1.2.0,<2.0.0",
4141
"pandas>=0.25.0,<0.26.0",

snorkel/utils/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def to_int_label_array(X: np.ndarray, flatten_vector: bool = True) -> np.ndarray
121121
# Correct shape
122122
if flatten_vector:
123123
X = X.squeeze()
124+
if X.ndim == 0:
125+
X = np.expand_dims(X, 0)
124126
if X.ndim != 1:
125127
raise ValueError("Input could not be converted to 1d np.array")
126128
return X

test/utils/test_core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def test_to_int_label_array(self):
2121
Y = to_int_label_array(X, flatten_vector=True)
2222
np.testing.assert_array_equal(Y, Y_expected)
2323

24+
Y = to_int_label_array(np.array([[1]]), flatten_vector=True)
25+
Y_expected = np.array([1])
26+
np.testing.assert_array_equal(Y, Y_expected)
27+
2428
Y = to_int_label_array(X, flatten_vector=False)
2529
Y_expected = np.array([[1], [0], [2]])
2630
np.testing.assert_array_equal(Y, Y_expected)

0 commit comments

Comments
 (0)