File tree Expand file tree Collapse file tree 9 files changed +17
-19
lines changed
classification/training/loggers Expand file tree Collapse file tree 9 files changed +17
-19
lines changed Original file line number Diff line number Diff line change @@ -18,15 +18,14 @@ tqdm>=4.33.0,<5.0.0
18
18
19
19
# Internal models
20
20
scikit-learn >= 0.20.2 ,< 0.22.0
21
- torch >= 1.1 .0 ,< 1.2 .0
21
+ torch >= 1.2 .0 ,< 2.0 .0
22
22
munkres >= 1.0.6
23
23
24
24
# LF dependency learning
25
25
networkx >= 2.2 ,< 2.4
26
26
27
27
# Model introspection tools
28
- tensorboardX >= 1.6 ,< 2.0
29
-
28
+ tensorboard >= 1.14.0 ,< 2.0.0
30
29
31
30
#### EXTRA/TEST LIBRARIES
32
31
Original file line number Diff line number Diff line change @@ -101,9 +101,6 @@ ignore_missing_imports = True
101
101
[mypy-spacy]
102
102
ignore_missing_imports = True
103
103
104
- [mypy-tensorboardX]
105
- ignore_missing_imports = True
106
-
107
104
[mypy-tqdm]
108
105
ignore_missing_imports = True
109
106
Original file line number Diff line number Diff line change 41
41
"pandas>=0.25.0,<0.26.0" ,
42
42
"tqdm>=4.33.0,<5.0.0" ,
43
43
"scikit-learn>=0.20.2,<0.22.0" ,
44
- "torch>=1.1.0,<1.2.0" ,
44
+ "torch>=1.2.0,<2.0.0" ,
45
+ "tensorboard>=1.14.0,<2.0.0" ,
45
46
"networkx>=2.2,<2.4" ,
46
- "tensorboardX>=1.6,<2.0" ,
47
47
],
48
48
python_requires = ">=3.6" ,
49
49
keywords = "machine-learning ai weak-supervision" ,
Original file line number Diff line number Diff line change 1
1
from typing import Any
2
2
3
- from tensorboardX import SummaryWriter
3
+ from torch . utils . tensorboard import SummaryWriter
4
4
5
5
from snorkel .types import Config
6
6
@@ -20,7 +20,7 @@ class TensorBoardWriter(LogWriter):
20
20
Attributes
21
21
----------
22
22
writer
23
- tensorboardX ``SummaryWriter`` for logging and visualization
23
+ ``SummaryWriter`` for logging and visualization
24
24
"""
25
25
26
26
def __init__ (self , ** kwargs : Any ) -> None :
Original file line number Diff line number Diff line change @@ -197,9 +197,6 @@ def fit(
197
197
total_batch_num = epoch_num * self .n_batches_per_epoch + batch_num
198
198
batch_size = len (next (iter (Y_dict .values ())))
199
199
200
- # Update lr using lr scheduler
201
- self ._update_lr_scheduler (total_batch_num )
202
-
203
200
# Set gradients of all model parameters to zero
204
201
self .optimizer .zero_grad ()
205
202
@@ -240,6 +237,9 @@ def fit(
240
237
# Update the parameters
241
238
self .optimizer .step ()
242
239
240
+ # Update lr using lr scheduler
241
+ self ._update_lr_scheduler (total_batch_num )
242
+
243
243
# Update metrics
244
244
self .metrics .update (self ._logging (model , dataloaders , batch_size ))
245
245
Original file line number Diff line number Diff line change @@ -227,7 +227,7 @@ def _get_augmented_label_matrix(
227
227
228
228
def _build_mask (self ) -> None :
229
229
"""Build mask applied to O^{-1}, O for the matrix approx constraint."""
230
- self .mask = torch .ones (self .d , self .d ).byte ()
230
+ self .mask = torch .ones (self .d , self .d ).bool ()
231
231
for ci in self .c_data .values ():
232
232
si = ci .start_index
233
233
ei = ci .end_index
@@ -296,7 +296,9 @@ def _init_params(self) -> None:
296
296
self .mu_init [idx , y ] += mu_init
297
297
298
298
# Initialize randomly based on self.mu_init
299
- self .mu = nn .Parameter (self .mu_init .clone () * np .random .random ()).float ()
299
+ self .mu = nn .Parameter ( # type: ignore
300
+ self .mu_init .clone () * np .random .random ()
301
+ ).float ()
300
302
301
303
# Build the mask over O^{-1}
302
304
self ._build_mask ()
@@ -784,7 +786,7 @@ def _break_col_permutation_symmetry(self) -> None:
784
786
Z [group [i ], group [j ]] = 1.0
785
787
786
788
# Set mu according to permutation
787
- self .mu = nn .Parameter (
789
+ self .mu = nn .Parameter ( # type: ignore
788
790
torch .Tensor (mu @ Z ).to (self .config .device ) # type: ignore
789
791
)
790
792
Original file line number Diff line number Diff line change @@ -45,7 +45,7 @@ def add_slice_labels(
45
45
46
46
# Mask out "inactive" pred_labels as specified by ind_labels
47
47
pred_labels = labels .clone ()
48
- pred_labels [~ ind_labels .byte ()] = - 1
48
+ pred_labels [~ ind_labels .bool ()] = - 1
49
49
50
50
ind_task_name = f"{ base_task .name } _slice:{ slice_name } _ind"
51
51
pred_task_name = f"{ base_task .name } _slice:{ slice_name } _pred"
Original file line number Diff line number Diff line change @@ -22,7 +22,7 @@ def tearDown(self):
22
22
23
23
def test_tensorboard_writer (self ):
24
24
# Note: this just tests API calls. We rely on
25
- # tensorboardX 's unit tests for correctness.
25
+ # tensorboard 's unit tests for correctness.
26
26
run_name = "my_run"
27
27
config = TempConfig (b = "bar" )
28
28
writer = TensorBoardWriter (run_name = run_name , log_dir = self .test_dir )
Original file line number Diff line number Diff line change @@ -149,7 +149,7 @@ def test_performance(self):
149
149
150
150
# Train
151
151
# NOTE: Needs more epochs to convergence with more heads
152
- trainer = Trainer (lr = 0.001 , n_epochs = 60 , progress_bar = False )
152
+ trainer = Trainer (lr = 0.001 , n_epochs = 65 , progress_bar = False )
153
153
trainer .fit (model , dataloaders )
154
154
scores = model .score (dataloaders )
155
155
You can’t perform that action at this time.
0 commit comments