Skip to content

Commit 1adc7a8

Browse files
authored
Upgrade pytorch requirement to >=1.2.0 (#1565)
Description of proposed changes * Update dev + setup.py requirements for torch * Use native torch.utils.tensorboard and remove tensorboardX dep * Moved lr_scheduler.step() after optimizer.step() * Replace deprecated uses of .byte() with .bool() Related issue(s) * Fixes #1558 Test plan * All unit tests pass with torch==1.2.0, 1.3.0, 1.4.0 * tox -e {complex,spark} passes with torch==1.2.0, 1.3.0, 1.4.0 * Updated Slicing convergence test for longer fit. * Tested Tensorboard output using tutorials
1 parent b1e811c commit 1adc7a8

File tree

9 files changed

+17
-19
lines changed

9 files changed

+17
-19
lines changed

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@ tqdm>=4.33.0,<5.0.0
1818

1919
# Internal models
2020
scikit-learn>=0.20.2,<0.22.0
21-
torch>=1.1.0,<1.2.0
21+
torch>=1.2.0,<2.0.0
2222
munkres>=1.0.6
2323

2424
# LF dependency learning
2525
networkx>=2.2,<2.4
2626

2727
# Model introspection tools
28-
tensorboardX>=1.6,<2.0
29-
28+
tensorboard>=1.14.0,<2.0.0
3029

3130
#### EXTRA/TEST LIBRARIES
3231

setup.cfg

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@ ignore_missing_imports = True
101101
[mypy-spacy]
102102
ignore_missing_imports = True
103103

104-
[mypy-tensorboardX]
105-
ignore_missing_imports = True
106-
107104
[mypy-tqdm]
108105
ignore_missing_imports = True
109106

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@
4141
"pandas>=0.25.0,<0.26.0",
4242
"tqdm>=4.33.0,<5.0.0",
4343
"scikit-learn>=0.20.2,<0.22.0",
44-
"torch>=1.1.0,<1.2.0",
44+
"torch>=1.2.0,<2.0.0",
45+
"tensorboard>=1.14.0,<2.0.0",
4546
"networkx>=2.2,<2.4",
46-
"tensorboardX>=1.6,<2.0",
4747
],
4848
python_requires=">=3.6",
4949
keywords="machine-learning ai weak-supervision",

snorkel/classification/training/loggers/tensorboard_writer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from tensorboardX import SummaryWriter
3+
from torch.utils.tensorboard import SummaryWriter
44

55
from snorkel.types import Config
66

@@ -20,7 +20,7 @@ class TensorBoardWriter(LogWriter):
2020
Attributes
2121
----------
2222
writer
23-
tensorboardX ``SummaryWriter`` for logging and visualization
23+
``SummaryWriter`` for logging and visualization
2424
"""
2525

2626
def __init__(self, **kwargs: Any) -> None:

snorkel/classification/training/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,6 @@ def fit(
197197
total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
198198
batch_size = len(next(iter(Y_dict.values())))
199199

200-
# Update lr using lr scheduler
201-
self._update_lr_scheduler(total_batch_num)
202-
203200
# Set gradients of all model parameters to zero
204201
self.optimizer.zero_grad()
205202

@@ -240,6 +237,9 @@ def fit(
240237
# Update the parameters
241238
self.optimizer.step()
242239

240+
# Update lr using lr scheduler
241+
self._update_lr_scheduler(total_batch_num)
242+
243243
# Update metrics
244244
self.metrics.update(self._logging(model, dataloaders, batch_size))
245245

snorkel/labeling/model/label_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _get_augmented_label_matrix(
227227

228228
def _build_mask(self) -> None:
229229
"""Build mask applied to O^{-1}, O for the matrix approx constraint."""
230-
self.mask = torch.ones(self.d, self.d).byte()
230+
self.mask = torch.ones(self.d, self.d).bool()
231231
for ci in self.c_data.values():
232232
si = ci.start_index
233233
ei = ci.end_index
@@ -296,7 +296,9 @@ def _init_params(self) -> None:
296296
self.mu_init[idx, y] += mu_init
297297

298298
# Initialize randomly based on self.mu_init
299-
self.mu = nn.Parameter(self.mu_init.clone() * np.random.random()).float()
299+
self.mu = nn.Parameter( # type: ignore
300+
self.mu_init.clone() * np.random.random()
301+
).float()
300302

301303
# Build the mask over O^{-1}
302304
self._build_mask()
@@ -784,7 +786,7 @@ def _break_col_permutation_symmetry(self) -> None:
784786
Z[group[i], group[j]] = 1.0
785787

786788
# Set mu according to permutation
787-
self.mu = nn.Parameter(
789+
self.mu = nn.Parameter( # type: ignore
788790
torch.Tensor(mu @ Z).to(self.config.device) # type: ignore
789791
)
790792

snorkel/slicing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def add_slice_labels(
4545

4646
# Mask out "inactive" pred_labels as specified by ind_labels
4747
pred_labels = labels.clone()
48-
pred_labels[~ind_labels.byte()] = -1
48+
pred_labels[~ind_labels.bool()] = -1
4949

5050
ind_task_name = f"{base_task.name}_slice:{slice_name}_ind"
5151
pred_task_name = f"{base_task.name}_slice:{slice_name}_pred"

test/classification/training/loggers/test_tensorboard_writer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def tearDown(self):
2222

2323
def test_tensorboard_writer(self):
2424
# Note: this just tests API calls. We rely on
25-
# tensorboardX's unit tests for correctness.
25+
# tensorboard's unit tests for correctness.
2626
run_name = "my_run"
2727
config = TempConfig(b="bar")
2828
writer = TensorBoardWriter(run_name=run_name, log_dir=self.test_dir)

test/slicing/test_convergence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_performance(self):
149149

150150
# Train
151151
# NOTE: Needs more epochs to convergence with more heads
152-
trainer = Trainer(lr=0.001, n_epochs=60, progress_bar=False)
152+
trainer = Trainer(lr=0.001, n_epochs=65, progress_bar=False)
153153
trainer.fit(model, dataloaders)
154154
scores = model.score(dataloaders)
155155

0 commit comments

Comments
 (0)