snorkel/classification/multitask_classifier.py (18 changes: 9 additions & 9 deletions)

@@ -360,17 +360,17 @@ def predict(
                 prob_dict_list[label_name].extend(prob_batch_dict[task_name])
                 gold_dict_list[label_name].extend(Y.cpu().numpy())

-        gold_dict: Dict[str, np.ndarray] = {}
-        prob_dict: Dict[str, np.ndarray] = {}
+        gold_dict: Dict[str, torch.Tensor] = {}
+        prob_dict: Dict[str, torch.Tensor] = {}

         for task_name in gold_dict_list:
-            gold_dict[task_name] = np.array(gold_dict_list[task_name])
-            prob_dict[task_name] = np.array(prob_dict_list[task_name])
+            gold_dict[task_name] = torch.Tensor(np.array(gold_dict_list[task_name]))
+            prob_dict[task_name] = torch.Tensor(np.array(prob_dict_list[task_name]))

         if return_preds:
-            pred_dict: Dict[str, np.ndarray] = defaultdict(list)
+            pred_dict: Dict[str, torch.Tensor] = defaultdict(np.ndarray)
             for task_name, probs in prob_dict.items():
-                pred_dict[task_name] = probs_to_preds(probs)
+                pred_dict[task_name] = torch.Tensor(probs_to_preds(probs.numpy()))

         results = {"golds": gold_dict, "probs": prob_dict}

@@ -431,9 +431,9 @@ def score(
         # Score and record metrics for each set of predictions
         for label_name, task_name in labels_to_tasks.items():
             metric_scores = self.scorers[task_name].score(
-                golds=results["golds"][label_name],
-                preds=results["preds"][label_name],
-                probs=results["probs"][label_name],
+                golds=results["golds"][label_name].numpy(),
+                preds=results["preds"][label_name].numpy(),
+                probs=results["probs"][label_name].numpy(),
             )

             for metric_name, metric_value in metric_scores.items():
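Note on the hunks above: `predict` now stores golds, probs, and preds as `torch.Tensor`, and `score` converts back with `.numpy()` before scoring. A minimal sketch of that round trip with toy data, using `argmax` only as a stand-in for `snorkel.utils.probs_to_preds`:

```python
import numpy as np
import torch

# Per-task probabilities collected as lists of numpy rows (toy data).
prob_dict_list = {"spam_task": [np.array([0.9, 0.1]), np.array([0.2, 0.8])]}

# As in the diff: stack with np.array, then wrap in a torch.Tensor.
prob_dict = {
    name: torch.Tensor(np.array(rows)) for name, rows in prob_dict_list.items()
}

# Downstream consumers convert back with .numpy(); argmax stands in
# for snorkel.utils.probs_to_preds here.
pred_dict = {
    name: torch.Tensor(probs.numpy().argmax(axis=1))
    for name, probs in prob_dict.items()
}
print(pred_dict["spam_task"])  # tensor([0., 1.])
```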
snorkel/labeling/apply/core.py (2 changes: 1 addition & 1 deletion)

@@ -69,7 +69,7 @@ def _numpy_from_row_data(self, labels: List[RowData]) -> np.ndarray:
         if self._use_recarray:
             n_rows, _ = L.shape
             dtype = [(name, np.int64) for name in self._lf_names]
-            recarray = np.recarray(n_rows, dtype=dtype)
+            recarray: np.ndarray = np.recarray(n_rows, dtype=dtype)
             for idx, name in enumerate(self._lf_names):
                 recarray[name] = L[:, idx]
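For reference, the annotated `np.recarray` exposes each labeling-function column as an attribute. A standalone sketch of the same construction with a toy label matrix and hypothetical LF names:

```python
import numpy as np

# Toy label matrix (n_rows x n_lfs) and hypothetical LF names.
L = np.array([[0, 1], [1, -1], [1, 1]], dtype=np.int64)
lf_names = ["lf_keyword", "lf_regex"]

n_rows, _ = L.shape
dtype = [(name, np.int64) for name in lf_names]
recarray: np.ndarray = np.recarray(n_rows, dtype=dtype)  # structured array with attribute access
for idx, name in enumerate(lf_names):
    recarray[name] = L[:, idx]

print(recarray.lf_keyword)  # [0 1 1]
print(recarray.lf_regex)    # [ 1 -1  1]
```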
snorkel/labeling/model/label_model.py (6 changes: 4 additions & 2 deletions)

@@ -60,7 +60,7 @@ class TrainConfig(Config):
     lr_scheduler: str = "constant"
     lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig()  # type: ignore
     prec_init: Union[float, List[float], np.ndarray, torch.Tensor] = 0.7
-    seed: int = np.random.randint(1e6)
+    seed: int = np.random.randint(1e6)  # type: ignore
     log_freq: int = 10
     mu_eps: Optional[float] = None

@@ -557,7 +557,9 @@ def _loss_mu(self, l2: float = 0) -> torch.Tensor:
         return loss_1 + loss_2 + self._loss_l2(l2=l2)

     def _set_class_balance(
-        self, class_balance: Optional[List[float]], Y_dev: np.ndarray
+        self,
+        class_balance: Optional[List[float]],
+        Y_dev: Optional[np.ndarray] = None,
     ) -> None:
         """Set a prior for the class balance.
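The new signature gives `Y_dev` a `None` default. As an illustration only (not the library's actual `_set_class_balance` logic), a class-balance prior with that shape can come from an explicit list, be estimated from dev labels when they are supplied, or fall back to uniform:

```python
from typing import List, Optional

import numpy as np


def resolve_class_balance(
    class_balance: Optional[List[float]],
    Y_dev: Optional[np.ndarray] = None,
    cardinality: int = 2,
) -> np.ndarray:
    """Illustrative only: prior from an explicit balance, dev labels, or uniform."""
    if class_balance is not None:
        return np.asarray(class_balance, dtype=float)
    if Y_dev is not None:
        counts = np.bincount(Y_dev, minlength=cardinality).astype(float)
        return counts / counts.sum()
    return np.full(cardinality, 1.0 / cardinality)


print(resolve_class_balance(None, Y_dev=np.array([0, 1, 1, 1])))  # [0.25 0.75]
print(resolve_class_balance(None))                                # [0.5 0.5]
```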
snorkel/slicing/utils.py (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@ def add_slice_labels(

     slice_names = S.dtype.names

-    Y_dict: Dict[str, np.ndarray] = dataloader.dataset.Y_dict  # type: ignore
+    Y_dict: Dict[str, torch.Tensor] = dataloader.dataset.Y_dict  # type: ignore
     labels = Y_dict[base_task.name]

     for slice_name in slice_names:
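The updated annotation reflects that the dataset's `Y_dict` maps label names to `torch.Tensor` values rather than numpy arrays. A small illustrative sketch of that structure (the task and slice key names here are hypothetical):

```python
import torch

# Hypothetical Y_dict: label names -> torch.Tensor label vectors.
Y_dict = {"spam_task": torch.tensor([0, 1, 1, 0])}

# add_slice_labels-style code looks up the base task's labels ...
labels = Y_dict["spam_task"]

# ... and can register additional per-slice tensors under new keys
# (naming and values are illustrative, not snorkel's exact scheme).
Y_dict["spam_task_slice:short_ind"] = torch.tensor([1, 0, 1, 0])

print({name: tensor.dtype for name, tensor in Y_dict.items()})
```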
snorkel/utils/core.py (8 changes: 4 additions & 4 deletions)

@@ -1,5 +1,5 @@
 import hashlib
-from typing import Dict, List
+from typing import Dict, List, Optional

 import numpy as np

@@ -69,7 +69,7 @@ def probs_to_preds(
         raise ValueError(
             f"tie_break_policy={tie_break_policy} policy not recognized."
         )
-    return Y_pred.astype(np.int)
+    return Y_pred.astype(np.int_)


 def preds_to_probs(preds: np.ndarray, num_classes: int) -> np.ndarray:

@@ -129,7 +129,7 @@ def to_int_label_array(X: np.ndarray, flatten_vector: bool = True) -> np.ndarray


 def filter_labels(
-    label_dict: Dict[str, np.ndarray], filter_dict: Dict[str, List[int]]
+    label_dict: Dict[str, Optional[np.ndarray]], filter_dict: Dict[str, List[int]]
 ) -> Dict[str, np.ndarray]:
     """Filter out examples from arrays based on specified labels to filter.

@@ -195,7 +195,7 @@ def _get_mask(label_array: np.ndarray, filter_values: List[int]) -> np.ndarray:
     np.ndarray
         A boolean mask indicating whether to keep (1) or filter (0) each example
     """
-    mask = np.ones_like(label_array).astype(bool)
+    mask: np.ndarray = np.ones_like(label_array).astype(bool)
     for value in filter_values:
         mask *= np.where(label_array != value, 1, 0).astype(bool)
     return mask
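Since `_get_mask` appears in full in the hunk above, a self-contained usage sketch follows; the label values and the choice to filter `0` are toy examples only:

```python
from typing import List

import numpy as np


def _get_mask(label_array: np.ndarray, filter_values: List[int]) -> np.ndarray:
    """Boolean mask that is True where none of the filter values match (as in the diff)."""
    mask: np.ndarray = np.ones_like(label_array).astype(bool)
    for value in filter_values:
        mask *= np.where(label_array != value, 1, 0).astype(bool)
    return mask


labels = np.array([1, 0, 2, 0, 1])
mask = _get_mask(labels, filter_values=[0])  # drop examples labeled 0 (toy convention)
print(mask)          # [ True False  True False  True]
print(labels[mask])  # [1 2 1]
```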