Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
#### ESSENTIAL LIBRARIES

# General scientific computing
numpy>=1.16.5,<1.20.0

numpy>=1.16.5,<=1.22.3
scipy>=1.2.0,<2.0.0

# Data storage and function application
Expand All @@ -31,7 +32,7 @@ tensorboard>=2.0.0,<2.7.0

# spaCy (NLP)
spacy>=2.1.0,<3.0.0
blis>=0.3.0,<0.5.0
blis>=0.3.0,<=0.7.7

# Dask (parallelism)
dask[dataframe]>=2.1.0,<2.31.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
include_package_data=True,
install_requires=[
"munkres>=1.0.6",
"numpy>=1.16.5,<1.20.0",
"numpy>=1.16.5,<=1.22.3",
"scipy>=1.2.0,<2.0.0",
"pandas>=1.0.0,<2.0.0",
"tqdm>=4.33.0,<5.0.0",
Expand Down
10 changes: 8 additions & 2 deletions snorkel/analysis/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,19 @@ def metric_score(
preds = to_int_label_array(preds) if preds is not None else None

# Optionally filter out examples (e.g., abstain predictions or unknown labels)
label_dict = {"golds": golds, "preds": preds, "probs": probs}
label_dict: Dict[str, Optional[np.ndarray]] = {
"golds": golds,
"preds": preds,
"probs": probs,
}
if filter_dict:
if set(filter_dict.keys()).difference(set(label_dict.keys())):
raise ValueError(
"filter_dict must only include keys in ['golds', 'preds', 'probs']"
)
label_dict = filter_labels(label_dict, filter_dict)
# label_dict is overwritten from type Dict[str, Optional[np.ndarray]]
# to Dict[str, np.ndarray]
label_dict = filter_labels(label_dict, filter_dict) # type: ignore

# Confirm that required label sets are available
func, label_names = METRICS[metric]
Expand Down
2 changes: 1 addition & 1 deletion snorkel/classification/multitask_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def predict(
prob_dict[task_name] = torch.Tensor(np.array(prob_dict_list[task_name]))

if return_preds:
pred_dict: Dict[str, torch.Tensor] = defaultdict(np.ndarray)
pred_dict: Dict[str, torch.Tensor] = defaultdict(torch.Tensor)
for task_name, probs in prob_dict.items():
pred_dict[task_name] = torch.Tensor(probs_to_preds(probs.numpy()))

Expand Down
2 changes: 1 addition & 1 deletion snorkel/synthetic/synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def generate_simple_label_matrix(
Y = np.random.choice(cardinality, n)

# Generate the label matrix L
L = np.empty((n, m), dtype=int)
L: np.ndarray = np.empty((n, m), dtype=int)
for i in range(n):
for j in range(m):
L[i, j] = np.random.choice(cardinality + 1, p=P[j, :, Y[i]]) - 1
Expand Down
6 changes: 4 additions & 2 deletions snorkel/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@ def filter_labels(
"""
masks = []
for label_name, filter_values in filter_dict.items():
if label_dict[label_name] is not None:
masks.append(_get_mask(label_dict[label_name], filter_values))
label_array: Optional[np.ndarray] = label_dict.get(label_name)
if label_array is not None:
# _get_mask requires not-null input
masks.append(_get_mask(label_array, filter_values))
mask = (np.multiply(*masks) if len(masks) > 1 else masks[0]).squeeze()

filtered = {}
Expand Down