Skip to content

Commit 8abd274

Browse files
authored
Introduce multiprocess logger (#337)
1 parent b05d483 commit 8abd274

File tree

5 files changed

+69
-12
lines changed

5 files changed

+69
-12
lines changed

examples/by_feature/tracking.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import argparse
16-
import logging
1716
import os
1817

1918
import torch
@@ -30,9 +29,6 @@
3029
)
3130

3231

33-
logger = logging.getLogger(__name__)
34-
35-
3632
########################################################################
3733
# This is a fully working simple example to use Accelerate,
3834
# specifically showcasing the experiment tracking capability,

src/accelerate/accelerator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
2727
from .data_loader import prepare_data_loader
2828
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs, KwargsHandler
29+
from .logging import get_logger
2930
from .optimizer import AcceleratedOptimizer
3031
from .scheduler import AcceleratedScheduler
3132
from .state import AcceleratorState, DistributedType, is_deepspeed_available
@@ -52,10 +53,7 @@
5253

5354
from .deepspeed_utils import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
5455

55-
import logging
56-
57-
58-
logger = logging.getLogger(__name__)
56+
logger = get_logger(__name__)
5957

6058

6159
class Accelerator:

src/accelerate/checkpointing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
if is_tpu_available():
2929
import torch_xla.core.xla_model as xm
3030

31-
import logging
31+
from .logging import get_logger
3232

3333

34-
logger = logging.getLogger(__name__)
34+
logger = get_logger(__name__)
3535

3636

3737
def save_accelerator_state(

src/accelerate/logging.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
17+
from .state import AcceleratorState
18+
19+
20+
class MultiProcessAdapter(logging.LoggerAdapter):
21+
"""
22+
An adapter to assist with logging in multiprocess.
23+
24+
`log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
25+
or only the main executed one. Default is `main_process_only=True`.
26+
"""
27+
28+
@staticmethod
29+
def _should_log(main_process_only):
30+
"Check if log should be performed"
31+
return not main_process_only or (main_process_only and AcceleratorState().local_process_index == 0)
32+
33+
def log(self, level, msg, *args, **kwargs):
34+
"""
35+
Delegates logger call after checking if we should log.
36+
37+
Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
38+
or only the main executed one. Default is `True` if not passed
39+
"""
40+
main_process_only = kwargs.pop("main_process_only", True)
41+
if self.isEnabledFor(level) and self._should_log(main_process_only):
42+
msg, kwargs = self.process(msg, kwargs)
43+
self.logger.log(level, msg, *args, **kwargs)
44+
45+
46+
def get_logger(name: str):
47+
"""
48+
Returns a `logging.Logger` for `name` that can handle multiprocessing.
49+
50+
If a log should be called on all processes, pass `main_process_only=False`
51+
52+
E.g.
53+
```python
54+
logger.info("My log", main_process_only=False)
55+
logger.debug("My log", main_process_only=False)
56+
```
57+
58+
Args:
59+
name (`str`):
60+
The name for the logger, such as `__file__`
61+
"""
62+
logger = logging.getLogger(name)
63+
return MultiProcessAdapter(logger, {})

src/accelerate/tracking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
# Expectation:
1616
# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
1717

18-
import logging
1918
import os
2019
from abc import ABCMeta, abstractmethod, abstractproperty
2120
from typing import List, Optional, Union
2221

22+
from .logging import get_logger
2323
from .utils import LoggerType, is_comet_ml_available, is_tensorboard_available, is_wandb_available
2424

2525

@@ -41,7 +41,7 @@
4141
_available_trackers.append(LoggerType.COMETML)
4242

4343

44-
logger = logging.getLogger(__name__)
44+
logger = get_logger(__name__)
4545

4646

4747
def get_available_trackers():

0 commit comments

Comments
 (0)