Skip to content

Commit b2e8b93

Browse files
humzaiqbalAndreas Kodewitz
authored andcommitted
Fix linting (snorkel-team#1696)
* Fix linting * Reformat file * Fix failing UTs
1 parent 5fb2e23 commit b2e8b93

File tree

5 files changed

+19
-17
lines changed

5 files changed

+19
-17
lines changed

snorkel/classification/multitask_classifier.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -360,17 +360,17 @@ def predict(
360360
prob_dict_list[label_name].extend(prob_batch_dict[task_name])
361361
gold_dict_list[label_name].extend(Y.cpu().numpy())
362362

363-
gold_dict: Dict[str, np.ndarray] = {}
364-
prob_dict: Dict[str, np.ndarray] = {}
363+
gold_dict: Dict[str, torch.Tensor] = {}
364+
prob_dict: Dict[str, torch.Tensor] = {}
365365

366366
for task_name in gold_dict_list:
367-
gold_dict[task_name] = np.array(gold_dict_list[task_name])
368-
prob_dict[task_name] = np.array(prob_dict_list[task_name])
367+
gold_dict[task_name] = torch.Tensor(np.array(gold_dict_list[task_name]))
368+
prob_dict[task_name] = torch.Tensor(np.array(prob_dict_list[task_name]))
369369

370370
if return_preds:
371-
pred_dict: Dict[str, np.ndarray] = defaultdict(list)
371+
pred_dict: Dict[str, torch.Tensor] = defaultdict(np.ndarray)
372372
for task_name, probs in prob_dict.items():
373-
pred_dict[task_name] = probs_to_preds(probs)
373+
pred_dict[task_name] = torch.Tensor(probs_to_preds(probs.numpy()))
374374

375375
results = {"golds": gold_dict, "probs": prob_dict}
376376

@@ -431,9 +431,9 @@ def score(
431431
# Score and record metrics for each set of predictions
432432
for label_name, task_name in labels_to_tasks.items():
433433
metric_scores = self.scorers[task_name].score(
434-
golds=results["golds"][label_name],
435-
preds=results["preds"][label_name],
436-
probs=results["probs"][label_name],
434+
golds=results["golds"][label_name].numpy(),
435+
preds=results["preds"][label_name].numpy(),
436+
probs=results["probs"][label_name].numpy(),
437437
)
438438

439439
for metric_name, metric_value in metric_scores.items():

snorkel/labeling/apply/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _numpy_from_row_data(self, labels: List[RowData]) -> np.ndarray:
6969
if self._use_recarray:
7070
n_rows, _ = L.shape
7171
dtype = [(name, np.int64) for name in self._lf_names]
72-
recarray = np.recarray(n_rows, dtype=dtype)
72+
recarray: np.ndarray = np.recarray(n_rows, dtype=dtype)
7373
for idx, name in enumerate(self._lf_names):
7474
recarray[name] = L[:, idx]
7575

snorkel/labeling/model/label_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class TrainConfig(Config):
6060
lr_scheduler: str = "constant"
6161
lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig() # type: ignore
6262
prec_init: Union[float, List[float], np.ndarray, torch.Tensor] = 0.7
63-
seed: int = np.random.randint(1e6)
63+
seed: int = np.random.randint(1e6) # type: ignore
6464
log_freq: int = 10
6565
mu_eps: Optional[float] = None
6666

@@ -557,7 +557,9 @@ def _loss_mu(self, l2: float = 0) -> torch.Tensor:
557557
return loss_1 + loss_2 + self._loss_l2(l2=l2)
558558

559559
def _set_class_balance(
560-
self, class_balance: Optional[List[float]], Y_dev: np.ndarray
560+
self,
561+
class_balance: Optional[List[float]],
562+
Y_dev: Optional[np.ndarray] = None,
561563
) -> None:
562564
"""Set a prior for the class balance.
563565

snorkel/slicing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def add_slice_labels(
3636

3737
slice_names = S.dtype.names
3838

39-
Y_dict: Dict[str, np.ndarray] = dataloader.dataset.Y_dict # type: ignore
39+
Y_dict: Dict[str, torch.Tensor] = dataloader.dataset.Y_dict # type: ignore
4040
labels = Y_dict[base_task.name]
4141

4242
for slice_name in slice_names:

snorkel/utils/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import hashlib
2-
from typing import Dict, List
2+
from typing import Dict, List, Optional
33

44
import numpy as np
55

@@ -69,7 +69,7 @@ def probs_to_preds(
6969
raise ValueError(
7070
f"tie_break_policy={tie_break_policy} policy not recognized."
7171
)
72-
return Y_pred.astype(np.int)
72+
return Y_pred.astype(np.int_)
7373

7474

7575
def preds_to_probs(preds: np.ndarray, num_classes: int) -> np.ndarray:
@@ -129,7 +129,7 @@ def to_int_label_array(X: np.ndarray, flatten_vector: bool = True) -> np.ndarray
129129

130130

131131
def filter_labels(
132-
label_dict: Dict[str, np.ndarray], filter_dict: Dict[str, List[int]]
132+
label_dict: Dict[str, Optional[np.ndarray]], filter_dict: Dict[str, List[int]]
133133
) -> Dict[str, np.ndarray]:
134134
"""Filter out examples from arrays based on specified labels to filter.
135135
@@ -195,7 +195,7 @@ def _get_mask(label_array: np.ndarray, filter_values: List[int]) -> np.ndarray:
195195
np.ndarray
196196
A boolean mask indicating whether to keep (1) or filter (0) each example
197197
"""
198-
mask = np.ones_like(label_array).astype(bool)
198+
mask: np.ndarray = np.ones_like(label_array).astype(bool)
199199
for value in filter_values:
200200
mask *= np.where(label_array != value, 1, 0).astype(bool)
201201
return mask

0 commit comments

Comments
 (0)