Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TrainConfig(Config):
optimizer_config: OptimizerConfig = OptimizerConfig() # type: ignore
lr_scheduler: str = "constant"
lr_scheduler_config: LRSchedulerConfig = LRSchedulerConfig() # type: ignore
prec_init: float = 0.7
prec_init: Union[float, List[float], np.ndarray, torch.Tensor] = 0.7
seed: int = np.random.randint(1e6)
log_freq: int = 10
mu_eps: Optional[float] = None
Expand Down Expand Up @@ -280,6 +280,14 @@ def _init_params(self) -> None:
# Handle single values
if isinstance(self.train_config.prec_init, (int, float)):
self._prec_init = self.train_config.prec_init * torch.ones(self.m)
elif isinstance(self.train_config.prec_init, np.ndarray):
self._prec_init = torch.Tensor(self.train_config.prec_init)
elif isinstance(self.train_config.prec_init, list):
self._prec_init = torch.Tensor(self.train_config.prec_init)
elif not isinstance(self.train_config.prec_init, torch.Tensor):
raise TypeError(
f"prec_init is of type {type(self.train_config.prec_init)} which is not supported currently."
)
if self._prec_init.shape[0] != self.m:
raise ValueError(f"prec_init must have shape {self.m}.")

Expand Down
49 changes: 49 additions & 0 deletions test/labeling/model/test_label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,55 @@ def test_mv_default(self):
label_model.predict(L), np.array([1, 1, 1])
)

def test_prec_init(self):
    """Exercise ``LabelModel.fit`` with each accepted and rejected ``prec_init`` form.

    Accepted forms covered here: omitted, float, int, list, and ``np.ndarray``.
    Rejected forms covered here: a string (expects ``TypeError``) and arrays
    whose length differs from ``L.shape[1]`` (expects ``ValueError``).

    The same ``LabelModel`` instance is reused for every case, as in the
    original test. Each case runs in its own ``subTest`` so that one failure
    does not hide the others.
    """
    # Local import: only needed to turn the expected messages into literal
    # regex patterns for assertRaisesRegex.
    import re

    label_model = LabelModel(cardinality=2, verbose=False)
    L = np.array([[-1, -1, 1], [-1, 1, -1], [0, -1, -1]])
    fit_kwargs = {"L_train": L, "n_epochs": 1000, "seed": 123}

    # Baseline: fit without passing prec_init at all.
    label_model.fit(**fit_kwargs)

    # Supported input forms; each must fit and predict without raising.
    supported = {
        "float": 0.6,
        "int": 1,
        "list": [0.1, 0.2, 0.3],
        "np.ndarray": np.array([0.1, 0.2, 0.3]),
    }
    for form, prec_init in supported.items():
        with self.subTest(prec_init=form):
            label_model.fit(prec_init=prec_init, **fit_kwargs)
            label_model.predict(L)

    # Unsupported type (string). re.escape makes '.' and any other regex
    # metacharacters in the message match literally.
    type_msg = "prec_init is of type <class 'str'> which is not supported currently."
    with self.subTest(prec_init="str"):
        with self.assertRaisesRegex(TypeError, re.escape(type_msg)):
            label_model.fit(prec_init="skibidi bop mm dada", **fit_kwargs)

    # Arrays whose length does not equal L.shape[1] (one longer, one shorter).
    shape_msg = f"prec_init must have shape {L.shape[1]}."
    wrong_length = {
        "np.ndarray too long": np.array([0.1, 0.2, 0.3, 0.4]),
        "np.ndarray too short": np.array([0.1, 0.2]),
    }
    for form, prec_init in wrong_length.items():
        with self.subTest(prec_init=form):
            with self.assertRaisesRegex(ValueError, re.escape(shape_msg)):
                label_model.fit(prec_init=prec_init, **fit_kwargs)

def test_class_balance(self):
label_model = LabelModel(cardinality=2, verbose=False)
# Test class balance
Expand Down