Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ tqdm>=4.33.0,<5.0.0

# Internal models
scikit-learn>=0.20.2,<0.22.0
torch>=1.1.0,<1.2.0
torch>=1.2.0,<2.0.0
munkres>=1.0.6

# LF dependency learning
networkx>=2.2,<2.4

# Model introspection tools
tensorboardX>=1.6,<2.0

tensorboard>=1.14.0,<2.0.0

#### EXTRA/TEST LIBRARIES

Expand Down
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ ignore_missing_imports = True
[mypy-spacy]
ignore_missing_imports = True

[mypy-tensorboardX]
ignore_missing_imports = True

[mypy-tqdm]
ignore_missing_imports = True

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
"pandas>=0.25.0,<0.26.0",
"tqdm>=4.33.0,<5.0.0",
"scikit-learn>=0.20.2,<0.22.0",
"torch>=1.1.0,<1.2.0",
"torch>=1.2.0,<2.0.0",
"tensorboard>=1.14.0,<2.0.0",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how did we decide on this range?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got an import error that looked like this:

ImportError: TensorBoard logging requires TensorBoard with Python summary writer installed. This should be available in 1.14 or above

"networkx>=2.2,<2.4",
"tensorboardX>=1.6,<2.0",
],
python_requires=">=3.6",
keywords="machine-learning ai weak-supervision",
Expand Down
4 changes: 2 additions & 2 deletions snorkel/classification/training/loggers/tensorboard_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

from snorkel.types import Config

Expand All @@ -20,7 +20,7 @@ class TensorBoardWriter(LogWriter):
Attributes
----------
writer
tensorboardX ``SummaryWriter`` for logging and visualization
``SummaryWriter`` for logging and visualization
"""

def __init__(self, **kwargs: Any) -> None:
Expand Down
6 changes: 3 additions & 3 deletions snorkel/classification/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,6 @@ def fit(
total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
batch_size = len(next(iter(Y_dict.values())))

# Update lr using lr scheduler
self._update_lr_scheduler(total_batch_num)

# Set gradients of all model parameters to zero
self.optimizer.zero_grad()

Expand Down Expand Up @@ -240,6 +237,9 @@ def fit(
# Update the parameters
self.optimizer.step()

# Update lr using lr scheduler
self._update_lr_scheduler(total_batch_num)

# Update metrics
self.metrics.update(self._logging(model, dataloaders, batch_size))

Expand Down
8 changes: 5 additions & 3 deletions snorkel/labeling/model/label_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _get_augmented_label_matrix(

def _build_mask(self) -> None:
"""Build mask applied to O^{-1}, O for the matrix approx constraint."""
self.mask = torch.ones(self.d, self.d).byte()
self.mask = torch.ones(self.d, self.d).bool()
for ci in self.c_data.values():
si = ci.start_index
ei = ci.end_index
Expand Down Expand Up @@ -296,7 +296,9 @@ def _init_params(self) -> None:
self.mu_init[idx, y] += mu_init

# Initialize randomly based on self.mu_init
self.mu = nn.Parameter(self.mu_init.clone() * np.random.random()).float()
self.mu = nn.Parameter( # type: ignore
self.mu_init.clone() * np.random.random()
).float()

# Build the mask over O^{-1}
self._build_mask()
Expand Down Expand Up @@ -784,7 +786,7 @@ def _break_col_permutation_symmetry(self) -> None:
Z[group[i], group[j]] = 1.0

# Set mu according to permutation
self.mu = nn.Parameter(
self.mu = nn.Parameter( # type: ignore
torch.Tensor(mu @ Z).to(self.config.device) # type: ignore
)

Expand Down
2 changes: 1 addition & 1 deletion snorkel/slicing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def add_slice_labels(

# Mask out "inactive" pred_labels as specified by ind_labels
pred_labels = labels.clone()
pred_labels[~ind_labels.byte()] = -1
pred_labels[~ind_labels.bool()] = -1

ind_task_name = f"{base_task.name}_slice:{slice_name}_ind"
pred_task_name = f"{base_task.name}_slice:{slice_name}_pred"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def tearDown(self):

def test_tensorboard_writer(self):
# Note: this just tests API calls. We rely on
# tensorboardX's unit tests for correctness.
# tensorboard's unit tests for correctness.
run_name = "my_run"
config = TempConfig(b="bar")
writer = TensorBoardWriter(run_name=run_name, log_dir=self.test_dir)
Expand Down
2 changes: 1 addition & 1 deletion test/slicing/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_performance(self):

# Train
# NOTE: Needs more epochs to convergence with more heads
trainer = Trainer(lr=0.001, n_epochs=60, progress_bar=False)
trainer = Trainer(lr=0.001, n_epochs=65, progress_bar=False)
trainer.fit(model, dataloaders)
scores = model.score(dataloaders)

Expand Down