6 changes: 5 additions & 1 deletion snorkel/labeling/model/label_model.py
@@ -622,6 +622,8 @@ def _execute_logging(self, loss: torch.Tensor) -> Metrics:

     def _set_logger(self) -> None:
         self.logger = Logger(self.train_config.log_freq)
+        if self.config.verbose:
+            logging.basicConfig(level=logging.INFO)
Member

This appears to change the default logging level of the logging module globally, not just the level for this specific class's logger. Can you modify just the logging level of self.logger instead?
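
For illustration, a minimal stdlib sketch of what the reviewer is suggesting: scoping the level to one named logger instead of the root logger. The logger name and the bare `verbose` flag are stand-ins, and Snorkel's `Logger` here is a separate custom class, so this is illustrative only:

```python
import logging

# Hypothetical: a named logger plus its own handler, so the level change
# affects only records routed through this logger.
module_logger = logging.getLogger("snorkel.labeling.model.label_model")
module_logger.addHandler(logging.StreamHandler())

verbose = True  # stand-in for self.config.verbose
if verbose:
    module_logger.setLevel(logging.INFO)

module_logger.info("emitted: this logger accepts INFO")
logging.info("dropped: the root logger still defaults to WARNING")
```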

Contributor Author

I did consider this. However, logging has been implemented at the root level throughout the codebase: both fit() in LabelModel and log() in Logger use the module-level function logging.info(), so all of these logs are emitted by the root logger. The only way to set a log level for LabelModel specifically would be to change the custom Logger significantly, and I don't think that is necessary. Even if I change the root logger level to INFO, it will not interfere with the logging flow, because before an INFO-level log is created there is always a check on the verbose parameter.
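
A minimal sketch of the flow described above, assuming only stdlib logging; the `verbose` flag and the message text are stand-ins, not the actual implementation:

```python
import logging

# Module-level calls such as logging.info() are dispatched by the root
# logger, so lowering the root level is the lever _set_logger() pulls.
logging.basicConfig(level=logging.INFO)

verbose = True  # stand-in for self.config.verbose
# INFO-level logs in fit() / Logger.log() are already gated on verbose,
# so raising root verbosity does not emit logs the caller did not ask for.
if verbose:
    logging.info("illustrative training progress message")
```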

Member

I see. Alright, this looks good to me then!


     def _set_optimizer(self) -> None:
         parameters = filter(lambda p: p.requires_grad, self.parameters())
@@ -881,6 +883,9 @@ def fit(
         np.random.seed(self.train_config.seed)
         torch.manual_seed(self.train_config.seed)

+        # Set Logger
+        self._set_logger()
+
         L_shift = L_train + 1  # convert to {0, 1, ..., k}
         if L_shift.max() > self.cardinality:
             raise ValueError(
@@ -913,7 +918,6 @@
         self.to(self.config.device)

         # Set training components
-        self._set_logger()
         self._set_optimizer()
         self._set_lr_scheduler()
