
Commit abb74da

muellerzr authored and stevhliu committed
Update no_trainer examples to use new logger (huggingface#17044)
* Propagate and fix imports
1 parent c190e2b commit abb74da
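
For orientation, every script in this commit makes the same swap: the module-level `logging.getLogger(__name__)` becomes Accelerate's process-aware `get_logger(__name__)`, and the manual per-process `setLevel` gating is replaced by a `main_process_only` flag on each call. A minimal sketch of the new pattern (assuming a build of `accelerate` that ships `accelerate.logging`; the message strings are illustrative):

```python
import logging

from accelerate import Accelerator
from accelerate.logging import get_logger

# get_logger wraps logging.getLogger(__name__); in a distributed run its
# methods emit on the main process only, unless told otherwise per call.
logger = get_logger(__name__)

logging.basicConfig(
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

accelerator = Accelerator()
logger.info("printed once, by the main process")
logger.info(accelerator.state, main_process_only=False)  # printed by every process
```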

File tree: 13 files changed (+38 −76 lines)


examples/pytorch/README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -167,10 +167,10 @@ python xla_spawn.py --num_cores 8 \
 
 Most PyTorch example scripts have a version using the [🤗 Accelerate](https://github.com/huggingface/accelerate) library
 that exposes the training loop so it's easy for you to customize or tweak them to your needs. They all require you to
-install `accelerate` with
+install `accelerate` with the latest development version
 
 ```bash
-pip install accelerate
+pip install git+https://github.com/huggingface/accelerate
 ```
 
 Then you can easily launch any of the scripts by running
````
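
The git install matters here: at the time of this commit, `accelerate.logging.get_logger` presumably had not yet shipped in a released version, hence the switch away from `pip install accelerate`. A quick sanity check that the development build landed (a sketch; the printed version string will vary):

```python
# Confirm the installed accelerate exposes the new logging helper; this import
# fails with ImportError on a build that predates accelerate.logging.
import accelerate
from accelerate.logging import get_logger

print(accelerate.__version__)
```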

examples/pytorch/image-classification/run_image_classification_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -37,6 +37,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -50,7 +51,7 @@
 from transformers.utils.versions import require_version
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
 
@@ -188,11 +189,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```
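
Note what this diff keeps: the `is_local_main_process` guard around the datasets/transformers verbosity calls. `get_logger` manages only the script's own logger, not those libraries' internal ones, so the per-machine gating stays. A sketch of that arrangement (the `else` branch is an assumption about the surrounding code, which the three context lines above cut off):

```python
import datasets
import transformers
from accelerate import Accelerator

accelerator = Accelerator()

# accelerator.is_local_main_process is True for exactly one process per machine,
# so library loggers stay quiet everywhere else.
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
```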

examples/pytorch/language-modeling/run_clm_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -39,6 +39,7 @@
 
 import transformers
 from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -56,7 +57,7 @@
 from transformers.utils.versions import require_version
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
@@ -234,11 +235,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/language-modeling/run_mlm_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -39,6 +39,7 @@
 
 import transformers
 from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -56,7 +57,7 @@
 from transformers.utils.versions import require_version
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -245,11 +246,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/multiple-choice/run_swag_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -37,6 +37,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -54,7 +55,7 @@
 from transformers.utils import PaddingStrategy, get_full_repo_name
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -272,11 +273,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -35,6 +35,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -58,7 +59,7 @@
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
@@ -289,11 +290,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/question-answering/run_qa_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -35,6 +35,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository
 from transformers import (
@@ -60,7 +61,7 @@
 
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -318,11 +319,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -16,7 +16,6 @@
 
 import argparse
 import json
-import logging
 import math
 import os
 import random
@@ -34,6 +33,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from huggingface_hub import Repository, hf_hub_download
 from transformers import (
@@ -48,7 +48,7 @@
 from transformers.utils.versions import require_version
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
 
@@ -308,11 +308,7 @@ def main():
     # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
     # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
     accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```
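
Unlike the scripts above, this file and the wav2vec2 script below also drop `import logging` outright: once nothing in the module calls stdlib logging functions directly, the import serves no purpose, since `get_logger` builds on the logging module internally. A minimal sketch of what remains (same assumed accelerate build as above):

```python
# No stdlib `import logging` needed here; get_logger wraps it internally.
from accelerate import Accelerator
from accelerate.logging import get_logger

logger = get_logger(__name__)
accelerator = Accelerator()
logger.info(accelerator.state, main_process_only=False)
```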

examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -16,7 +16,6 @@
 """ Pre-Training a 🤗 Wav2Vec2 model on unlabeled audio data """
 
 import argparse
-import logging
 import math
 import os
 from dataclasses import dataclass
@@ -31,6 +30,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from huggingface_hub import Repository
 from transformers import (
     AdamW,
@@ -46,7 +46,7 @@
 from transformers.utils import get_full_repo_name
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 def parse_args():
@@ -362,11 +362,7 @@ def main():
 
     # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
     accelerator = Accelerator()
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```

examples/pytorch/summarization/run_summarization_no_trainer.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -36,6 +36,7 @@
 
 import transformers
 from accelerate import Accelerator
+from accelerate.logging import get_logger
 from accelerate.utils import set_seed
 from filelock import FileLock
 from huggingface_hub import Repository
@@ -54,7 +55,7 @@
 from transformers.utils.versions import require_version
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
 
 # You should update this to your particular problem to have better documentation of `model_type`
@@ -322,11 +323,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
-    logger.info(accelerator.state)
-
-    # Setup logging, we only want one process per machine to log things on the screen.
-    # accelerator.is_local_main_process is only True for one process per machine.
-    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
         datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_info()
```
