Skip to content

Commit 2df8f7d

Browse files
authored
Support m1 mac (#1698)
* update python version * py38 test now passes * fix mypy check errors * add more comment in code
1 parent cec6723 commit 2df8f7d

File tree

6 files changed

+16
-11
lines changed

6 files changed

+16
-11
lines changed

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#### ESSENTIAL LIBRARIES
1010

1111
# General scientific computing
12-
numpy>=1.16.5,<1.22.0
12+
13+
numpy>=1.16.5,<=1.22.3
1314
scipy>=1.2.0,<2.0.0
1415

1516
# Data storage and function application
@@ -31,7 +32,7 @@ tensorboard>=2.0.0,<2.7.0
3132

3233
# spaCy (NLP)
3334
spacy>=2.1.0,<3.0.0
34-
blis>=0.3.0,<0.5.0
35+
blis>=0.3.0,<=0.7.7
3536

3637
# Dask (parallelism)
3738
dask[dataframe]>=2.1.0,<2.31.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
include_package_data=True,
3737
install_requires=[
3838
"munkres>=1.0.6",
39-
"numpy>=1.16.5,<1.22.0",
39+
"numpy>=1.16.5,<=1.22.3",
4040
"scipy>=1.2.0,<2.0.0",
4141
"pandas>=1.0.0,<2.0.0",
4242
"tqdm>=4.33.0,<5.0.0",

snorkel/analysis/metrics.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,23 @@ def metric_score(
5858
preds = to_int_label_array(preds) if preds is not None else None
5959

6060
# Optionally filter out examples (e.g., abstain predictions or unknown labels)
61-
label_dict = {"golds": golds, "preds": preds, "probs": probs}
61+
label_dict: Dict[str, Optional[np.ndarray]] = {"golds": golds, "preds": preds, "probs": probs}
6262
if filter_dict:
6363
if set(filter_dict.keys()).difference(set(label_dict.keys())):
6464
raise ValueError(
6565
"filter_dict must only include keys in ['golds', 'preds', 'probs']"
6666
)
67-
label_dict = filter_labels(label_dict, filter_dict)
67+
# Reassign filtered label_dict to a new variable to avoid
68+
# mypy error regarding change variable of invariant type
69+
label_dict_filtered: Dict[str, np.ndarray] = filter_labels(label_dict, filter_dict)
6870

6971
# Confirm that required label sets are available
7072
func, label_names = METRICS[metric]
7173
for label_name in label_names:
72-
if label_dict[label_name] is None:
74+
if label_dict_filtered[label_name] is None:
7375
raise ValueError(f"Metric {metric} requires access to {label_name}.")
7476

75-
label_sets = [label_dict[label_name] for label_name in label_names]
77+
label_sets = [label_dict_filtered[label_name] for label_name in label_names]
7678
return func(*label_sets, **kwargs)
7779

7880

snorkel/classification/multitask_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def predict(
368368
prob_dict[task_name] = torch.Tensor(np.array(prob_dict_list[task_name]))
369369

370370
if return_preds:
371-
pred_dict: Dict[str, torch.Tensor] = defaultdict(np.ndarray)
371+
pred_dict: Dict[str, torch.Tensor] = defaultdict(torch.Tensor)
372372
for task_name, probs in prob_dict.items():
373373
pred_dict[task_name] = torch.Tensor(probs_to_preds(probs.numpy()))
374374

snorkel/synthetic/synthetic_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def generate_simple_label_matrix(
5252
Y = np.random.choice(cardinality, n)
5353

5454
# Generate the label matrix L
55-
L = np.empty((n, m), dtype=int)
55+
L: np.ndarray = np.empty((n, m), dtype=int)
5656
for i in range(n):
5757
for j in range(m):
5858
L[i, j] = np.random.choice(cardinality + 1, p=P[j, :, Y[i]]) - 1

snorkel/utils/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ def filter_labels(
170170
"""
171171
masks = []
172172
for label_name, filter_values in filter_dict.items():
173-
if label_dict[label_name] is not None:
174-
masks.append(_get_mask(label_dict[label_name], filter_values))
173+
label_array: Optional[np.ndarray] = label_dict.get(label_name)
174+
if label_array is not None:
175+
# _get_mask requires not-null input
176+
masks.append(_get_mask(label_array, filter_values))
175177
mask = (np.multiply(*masks) if len(masks) > 1 else masks[0]).squeeze()
176178

177179
filtered = {}

0 commit comments

Comments
 (0)