Description
When I try to save a pretrained BERT checkpoint using Accelerate's save, it only saves the config file and not the pytorch_model.bin file. I noticed this issue while working with the Bert-Large model; I did not have it when I was working with the Bert-Base model a couple of months ago.
I'm not sure whether it's due to the version update or the larger file size of Bert-Large.
The code below should reproduce the error in a multi-GPU environment.
System Details:
- Accelerate version 0.6.2
- PyTorch version 1.8.1+cu111
- Transformers version 4.18.0
- CUDA version 11.2
- 3 V100D-32C GPUs
- 5.4.0-62-generic GNU/Linux
import math
import torch
import transformers
from transformers import BertForPreTraining
from accelerate import Accelerator, notebook_launcher

def sample(model, accelerator):
    accelerator.wait_for_everyone()
    accelerator.print('Saving Checkpoint')
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained('ckpts', save_function=accelerator.save)

def demo():
    accelerator = Accelerator()
    device = accelerator.device
    accelerator.print("DEVICES IN USE:")
    print(device)
    model = BertForPreTraining.from_pretrained('bert-base-uncased')
    model = accelerator.prepare(model)
    sample(model, accelerator)

notebook_launcher(demo, num_processes=3)
When I try to save after training for a few steps, it sometimes produces the following error instead (my assumption is that training desynchronizes the processes further, which explains the different behavior):
ProcessRaisedException Traceback (most recent call last)
<ipython-input-8-78beb23515b9> in <module>
1 # Bert base / large model
2 FLAGS.bert_model = 'LARGE'
----> 3 hyperparameter_search(FLAGS,hyper_params_list, 'V1_LARGE', metric='EVAL_MLM_A')
<ipython-input-6-74c6cf480cc7> in hyperparameter_search(base_flags, hyper_params, version, metric)
22 base_flags.eval_point = item[3]
23 base_flags.output_dir = '../ckpts00_large/'+version+'hypsearch'+'/'+model_name
---> 24 notebook_launcher(training_function, [base_flags], num_processes=base_flags.num_gpu)
25
26
/opt/conda/lib/python3.8/site-packages/accelerate/launchers.py in notebook_launcher(function, args, num_processes, use_fp16, mixed_precision, use_port)
126
127 print(f"Launching training on {num_processes} GPUs.")
--> 128 start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
129
130 else:
/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
186
187 # Loop on join until it returns True or raises an exception.
--> 188 while not context.join():
189 pass
190
/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
148 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
149 msg += original_trace
--> 150 raise ProcessRaisedException(msg, error_index, failed_process.pid)
151
152
ProcessRaisedException:
-- Process 2 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/opt/conda/lib/python3.8/site-packages/accelerate/utils.py", line 638, in __call__
self.launcher(*args)
File "/home/jovyan/vbert2021/notebooks/bert_pretraining.py", line 376, in training_function
log_list = train(FLAGS, model, data, optim, scheduler, accelerator,logger, eval_data)
File "/home/jovyan/vbert2021/notebooks/bert_pretraining.py", line 208, in train
unwrapped_model.save_pretrained(os.path.join(FLAGS.output_dir,'Steps_'+str(step_count-FLAGS.num_warmup_steps)),save_function = accelerator.save)
File "/opt/conda/lib/python3.8/site-packages/transformers/modeling_utils.py", line 1371, in save_pretrained
os.remove(full_filename)
FileNotFoundError: [Errno 2] No such file or directory: '../ckpts00_large/V1_LARGEhypsearch/V1_LARGE_1e-05_258/Steps_10/pytorch_model.bin'
In 0.6.2, I believe save_pretrained() runs in every process, so each process writes its own copy of the checkpoint; while clearing a previous save, one process deletes the pytorch_model.bin file that another process still expects.
https://huggingface.co/docs/accelerate/v0.6.0/en/accelerator#accelerate.Accelerator.save states that it saves once per machine, not once per process, so maybe my understanding of it is wrong.
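My mental model of the race, as a deliberately simplified and hypothetical sketch (this is not the transformers code, and the directory/file names below are placeholders): every process independently does "remove the previous pytorch_model.bin, then write a new one" in the same folder, so a process that reaches os.remove second finds the file already gone.

import os
import time
import multiprocessing as mp

CKPT_DIR = 'ckpts_race_demo'   # throwaway directory, name made up for the sketch
WEIGHTS = os.path.join(CKPT_DIR, 'pytorch_model.bin')

def fake_save(rank):
    # Each process "cleans the folder from a previous save" and then writes
    # its own copy of the weights, with no coordination between processes.
    if os.path.exists(WEIGHTS):
        time.sleep(0.1)        # widen the race window so the collision is likely
        os.remove(WEIGHTS)     # a process arriving second here raises FileNotFoundError
    with open(WEIGHTS, 'wb') as f:
        f.write(b'weights written by rank %d' % rank)

if __name__ == '__main__':
    os.makedirs(CKPT_DIR, exist_ok=True)
    open(WEIGHTS, 'wb').close()  # pretend a previous checkpoint already exists
    procs = [mp.Process(target=fake_save, args=(r,)) for r in range(3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Running this usually crashes at least one child process with the same FileNotFoundError on os.remove as in the traceback above.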
My current workaround is:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained('ckpts', save_function=accelerator.save)
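For completeness, this is the sample() function from the repro above with that guard applied; the guard is the only change:

def sample(model, accelerator):
    # All processes reach this point before any saving happens ...
    accelerator.wait_for_everyone()
    # ... but only the main process actually writes the checkpoint.
    if accelerator.is_main_process:
        accelerator.print('Saving Checkpoint')
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained('ckpts', save_function=accelerator.save)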
Thank you