Issue description
I wanted to specify the reliability of different weakly supervised LFs to the LabelModel. I noticed that we can specify these as precision priors through the prec_init parameter of the LabelModel.fit() method.
However, the prec_init parameter seems to accept only scalar values. Providing an array with one precision value per LF throws the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-18-e1744ea2ad84> in <module>
7
8 lm = snorkel.labeling.LabelModel()
----> 9 lm.fit(L_train, prec_init=prec_init, n_epochs=1000, lr=0.001, log_freq=1, seed=123)
10 y_pred = lm.predict(L_train)
11
~\AppData\Local\Continuum\anaconda3\envs\snorkel\lib\site-packages\snorkel\labeling\model\label_model.py in fit(self, L_train, Y_dev, class_balance, **kwargs)
749 logging.info("Computing O...")
750 self._generate_O(L_shift)
--> 751 self._init_params()
752
753 # Estimate \mu
~\AppData\Local\Continuum\anaconda3\envs\snorkel\lib\site-packages\snorkel\labeling\model\label_model.py in _init_params(self)
260 if isinstance(self.train_config.prec_init, (int, float)):
261 self._prec_init = self.train_config.prec_init * torch.ones(self.m)
--> 262 if self._prec_init.shape[0] != self.m:
263 raise ValueError(f"prec_init must have shape {self.m}.")
264
~\AppData\Local\Continuum\anaconda3\envs\snorkel\lib\site-packages\torch\nn\modules\module.py in __getattr__(self, name)
537 return modules[name]
538 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 539 type(self).__name__, name))
540
541 def __setattr__(self, name, value):
AttributeError: 'LabelModel' object has no attribute '_prec_init'
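The traceback points at the scalar check on line 260 of label_model.py: _init_params only assigns self._prec_init when prec_init is an int or float. Any other value leaves the attribute unset, so the shape check on line 262 fails with an AttributeError (raised by torch.nn.Module.__getattr__) instead of the intended ValueError. The standalone sketch below only mimics that isinstance check and does not call Snorkel. It shows which inputs take the scalar branch. Note that np.float32 also falls through, because np.float64 subclasses Python float and np.float32 does not.

```python
import numpy as np

# Mimics the scalar check shown at line 260 of the traceback (Snorkel 0.9.0).
SCALAR_TYPES = (int, float)

candidates = {
    "python float": 0.7,
    "np.float64": np.float64(0.7),  # subclass of Python float
    "np.float32": np.float32(0.7),  # NOT a subclass of Python float
    "per-LF array": np.random.rand(10),
}

for name, value in candidates.items():
    takes_scalar_branch = isinstance(value, SCALAR_TYPES)
    # When False, _init_params never assigns self._prec_init, so the
    # following shape check raises AttributeError instead of ValueError.
    print(f"{name}: takes scalar branch = {takes_scalar_branch}")
```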
Code example/repro steps
Below is a short code snippet that reproduces this error in a Jupyter notebook:
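import numpy as np
import snorkel.labeling
num_lfs = 10
num_samples = 10**6
# Random binary label matrix: -1 = abstain, 0/1 = class labels
L_train = np.random.randint(-1, 2, size=(num_samples, num_lfs), dtype=np.int8)
# One precision prior per LF, shape (num_lfs,)
prec_init = np.random.rand(num_lfs)
lm = snorkel.labeling.LabelModel()
# Raises AttributeError in Snorkel 0.9.0; a scalar prec_init is accepted
lm.fit(L_train, prec_init=prec_init, n_epochs=1000, lr=0.001, log_freq=1, seed=123)
y_pred = lm.predict(L_train)
df_lf_summary = snorkel.labeling.LFAnalysis(L_train).lf_summary(Y=y_pred, est_weights=lm.get_weights())
# display() is available by default inside Jupyter
display(df_lf_summary)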
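One possible fix is to give _init_params an explicit branch for array-like input. The sketch below assumes a hypothetical standalone helper, build_prec_init, written with NumPy instead of the torch tensors Snorkel uses internally. It is not part of Snorkel's API. It only illustrates two steps: broadcasting a scalar to all m LFs, and validating a per-LF array against m. Checking against numbers.Real also accepts NumPy float scalars such as np.float32, which the (int, float) check misses.

```python
import numbers

import numpy as np


def build_prec_init(prec_init, m):
    """Return a length-m float array of per-LF precision priors.

    Hypothetical helper (not Snorkel API): keeps the scalar behaviour
    from the traceback and adds the missing array-like branch.
    """
    if isinstance(prec_init, numbers.Real):
        # Scalar (including NumPy float scalars): one prior for every LF.
        prec = np.full(m, float(prec_init))
    else:
        # Array-like: flatten to 1-D so the shape check always has a value to read.
        prec = np.asarray(prec_init, dtype=float).reshape(-1)
    if prec.shape[0] != m:
        raise ValueError(f"prec_init must have shape {m}.")
    return prec


# Usage: one prior per LF for a 3-LF label matrix.
priors = build_prec_init([0.9, 0.6, 0.75], m=3)
```

With a wrong-length array, this design fails with the intended ValueError rather than an AttributeError.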
System info
- How you installed Snorkel (conda, pip, source): conda
- OS: Windows 10
- Python version: 3.7.4
- Snorkel version: 0.9.0