Skip to content

Commit d4b3e35

Browse files
authored
Don't push checkpoints to hub in no_trainer scripts (#16703)
Adds checkpoint prefixes to the gitignore if `push_to_hub` is used along with `checkpointing_steps`
1 parent c04619e commit d4b3e35

File tree

9 files changed

+63
-10
lines changed

9 files changed

+63
-10
lines changed

examples/pytorch/language-modeling/run_clm_no_trainer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import transformers
4141
from accelerate import Accelerator, DistributedType
42+
from accelerate.utils import set_seed
4243
from huggingface_hub import Repository
4344
from transformers import (
4445
CONFIG_MAPPING,
@@ -50,7 +51,6 @@
5051
SchedulerType,
5152
default_data_collator,
5253
get_scheduler,
53-
set_seed,
5454
)
5555
from transformers.utils import get_full_repo_name
5656
from transformers.utils.versions import require_version
@@ -258,6 +258,12 @@ def main():
258258
else:
259259
repo_name = args.hub_model_id
260260
repo = Repository(args.output_dir, clone_from=repo_name)
261+
262+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
263+
if "step_*" not in gitignore:
264+
gitignore.write("step_*\n")
265+
if "epoch_*" not in gitignore:
266+
gitignore.write("epoch_*\n")
261267
elif args.output_dir is not None:
262268
os.makedirs(args.output_dir, exist_ok=True)
263269
accelerator.wait_for_everyone()
@@ -542,7 +548,6 @@ def group_texts(examples):
542548
if args.output_dir is not None:
543549
output_dir = os.path.join(args.output_dir, output_dir)
544550
accelerator.save_state(output_dir)
545-
546551
if completed_steps >= args.max_train_steps:
547552
break
548553

examples/pytorch/language-modeling/run_mlm_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import transformers
4141
from accelerate import Accelerator, DistributedType
42+
from accelerate.utils import set_seed
4243
from huggingface_hub import Repository
4344
from transformers import (
4445
CONFIG_MAPPING,
@@ -50,7 +51,6 @@
5051
DataCollatorForLanguageModeling,
5152
SchedulerType,
5253
get_scheduler,
53-
set_seed,
5454
)
5555
from transformers.utils import get_full_repo_name
5656
from transformers.utils.versions import require_version
@@ -269,6 +269,12 @@ def main():
269269
else:
270270
repo_name = args.hub_model_id
271271
repo = Repository(args.output_dir, clone_from=repo_name)
272+
273+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
274+
if "step_*" not in gitignore:
275+
gitignore.write("step_*\n")
276+
if "epoch_*" not in gitignore:
277+
gitignore.write("epoch_*\n")
272278
elif args.output_dir is not None:
273279
os.makedirs(args.output_dir, exist_ok=True)
274280
accelerator.wait_for_everyone()

examples/pytorch/multiple-choice/run_swag_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import transformers
3939
from accelerate import Accelerator
40+
from accelerate.utils import set_seed
4041
from huggingface_hub import Repository
4142
from transformers import (
4243
CONFIG_MAPPING,
@@ -49,7 +50,6 @@
4950
SchedulerType,
5051
default_data_collator,
5152
get_scheduler,
52-
set_seed,
5353
)
5454
from transformers.utils import PaddingStrategy, get_full_repo_name
5555

@@ -296,6 +296,12 @@ def main():
296296
else:
297297
repo_name = args.hub_model_id
298298
repo = Repository(args.output_dir, clone_from=repo_name)
299+
300+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
301+
if "step_*" not in gitignore:
302+
gitignore.write("step_*\n")
303+
if "epoch_*" not in gitignore:
304+
gitignore.write("epoch_*\n")
299305
elif args.output_dir is not None:
300306
os.makedirs(args.output_dir, exist_ok=True)
301307
accelerator.wait_for_everyone()

examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import transformers
3636
from accelerate import Accelerator
37+
from accelerate.utils import set_seed
3738
from huggingface_hub import Repository
3839
from transformers import (
3940
AdamW,
@@ -45,7 +46,6 @@
4546
XLNetTokenizerFast,
4647
default_data_collator,
4748
get_scheduler,
48-
set_seed,
4949
)
5050
from transformers.utils import check_min_version, get_full_repo_name
5151
from transformers.utils.versions import require_version
@@ -290,6 +290,12 @@ def main():
290290
else:
291291
repo_name = args.hub_model_id
292292
repo = Repository(args.output_dir, clone_from=repo_name)
293+
294+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
295+
if "step_*" not in gitignore:
296+
gitignore.write("step_*\n")
297+
if "epoch_*" not in gitignore:
298+
gitignore.write("epoch_*\n")
293299
elif args.output_dir is not None:
294300
os.makedirs(args.output_dir, exist_ok=True)
295301
accelerator.wait_for_everyone()

examples/pytorch/question-answering/run_qa_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import transformers
3737
from accelerate import Accelerator
38+
from accelerate.utils import set_seed
3839
from huggingface_hub import Repository
3940
from transformers import (
4041
CONFIG_MAPPING,
@@ -48,7 +49,6 @@
4849
SchedulerType,
4950
default_data_collator,
5051
get_scheduler,
51-
set_seed,
5252
)
5353
from transformers.utils import check_min_version, get_full_repo_name
5454
from transformers.utils.versions import require_version
@@ -320,6 +320,12 @@ def main():
320320
else:
321321
repo_name = args.hub_model_id
322322
repo = Repository(args.output_dir, clone_from=repo_name)
323+
324+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
325+
if "step_*" not in gitignore:
326+
gitignore.write("step_*\n")
327+
if "epoch_*" not in gitignore:
328+
gitignore.write("epoch_*\n")
323329
elif args.output_dir is not None:
324330
os.makedirs(args.output_dir, exist_ok=True)
325331
accelerator.wait_for_everyone()

examples/pytorch/summarization/run_summarization_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import transformers
3838
from accelerate import Accelerator
39+
from accelerate.utils import set_seed
3940
from filelock import FileLock
4041
from huggingface_hub import Repository
4142
from transformers import (
@@ -48,7 +49,6 @@
4849
DataCollatorForSeq2Seq,
4950
SchedulerType,
5051
get_scheduler,
51-
set_seed,
5252
)
5353
from transformers.utils import get_full_repo_name, is_offline_mode
5454
from transformers.utils.versions import require_version
@@ -346,6 +346,12 @@ def main():
346346
else:
347347
repo_name = args.hub_model_id
348348
repo = Repository(args.output_dir, clone_from=repo_name)
349+
350+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
351+
if "step_*" not in gitignore:
352+
gitignore.write("step_*\n")
353+
if "epoch_*" not in gitignore:
354+
gitignore.write("epoch_*\n")
349355
elif args.output_dir is not None:
350356
os.makedirs(args.output_dir, exist_ok=True)
351357
accelerator.wait_for_everyone()

examples/pytorch/text-classification/run_glue_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import transformers
3030
from accelerate import Accelerator
31+
from accelerate.utils import set_seed
3132
from huggingface_hub import Repository
3233
from transformers import (
3334
AdamW,
@@ -39,7 +40,6 @@
3940
SchedulerType,
4041
default_data_collator,
4142
get_scheduler,
42-
set_seed,
4343
)
4444
from transformers.utils import get_full_repo_name
4545
from transformers.utils.versions import require_version
@@ -223,6 +223,12 @@ def main():
223223
else:
224224
repo_name = args.hub_model_id
225225
repo = Repository(args.output_dir, clone_from=repo_name)
226+
227+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
228+
if "step_*" not in gitignore:
229+
gitignore.write("step_*\n")
230+
if "epoch_*" not in gitignore:
231+
gitignore.write("epoch_*\n")
226232
elif args.output_dir is not None:
227233
os.makedirs(args.output_dir, exist_ok=True)
228234
accelerator.wait_for_everyone()

examples/pytorch/token-classification/run_ner_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import transformers
3636
from accelerate import Accelerator
37+
from accelerate.utils import set_seed
3738
from huggingface_hub import Repository
3839
from transformers import (
3940
CONFIG_MAPPING,
@@ -47,7 +48,6 @@
4748
SchedulerType,
4849
default_data_collator,
4950
get_scheduler,
50-
set_seed,
5151
)
5252
from transformers.utils import get_full_repo_name
5353
from transformers.utils.versions import require_version
@@ -277,6 +277,12 @@ def main():
277277
else:
278278
repo_name = args.hub_model_id
279279
repo = Repository(args.output_dir, clone_from=repo_name)
280+
281+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
282+
if "step_*" not in gitignore:
283+
gitignore.write("step_*\n")
284+
if "epoch_*" not in gitignore:
285+
gitignore.write("epoch_*\n")
280286
elif args.output_dir is not None:
281287
os.makedirs(args.output_dir, exist_ok=True)
282288
accelerator.wait_for_everyone()

examples/pytorch/translation/run_translation_no_trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import transformers
3737
from accelerate import Accelerator
38+
from accelerate.utils import set_seed
3839
from huggingface_hub import Repository
3940
from transformers import (
4041
CONFIG_MAPPING,
@@ -49,7 +50,6 @@
4950
SchedulerType,
5051
default_data_collator,
5152
get_scheduler,
52-
set_seed,
5353
)
5454
from transformers.utils import get_full_repo_name
5555
from transformers.utils.versions import require_version
@@ -319,6 +319,12 @@ def main():
319319
else:
320320
repo_name = args.hub_model_id
321321
repo = Repository(args.output_dir, clone_from=repo_name)
322+
323+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
324+
if "step_*" not in gitignore:
325+
gitignore.write("step_*\n")
326+
if "epoch_*" not in gitignore:
327+
gitignore.write("epoch_*\n")
322328
elif args.output_dir is not None:
323329
os.makedirs(args.output_dir, exist_ok=True)
324330
accelerator.wait_for_everyone()

0 commit comments

Comments
 (0)