Accelerate save() with PyTorch HuggingFace does not save large pytorch_model.bin files in a multi-gpu setting #325

@TejaGollapudi

When I try to save a PreTrainedBert checkpoint using Accelerate's save, only the config file is written and the bin file is missing. I noticed this while working with the Bert-Large model; I did not have this issue when I was working with the Bert-Base model a couple of months ago.

I'm not sure whether it's due to the version update or the larger file size of Bert-Large.

The code below should reproduce the error in a multi-GPU environment.


System Details:
  • Accelerate version 0.6.2
  • PyTorch version 1.8.1+cu111
  • Transformers version 4.18.0
  • CUDA version 11.2
  • 3 V100D-32C GPUs
  • 5.4.0-62-generic GNU/Linux
import math
import torch
from transformers import BertForPreTraining  # needed by demo(); the bare `import transformers` did not expose this name
from accelerate import Accelerator, notebook_launcher

def sample(model, accelerator):
    # Sync all processes, then call save_pretrained from every process.
    accelerator.wait_for_everyone()
    accelerator.print('Saving Checkpoint')
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained('ckpts', save_function=accelerator.save)

def demo():
    accelerator = Accelerator()
    device = accelerator.device
    accelerator.print("DEVICES IN USE:")
    print(device)  # plain print, so this runs once per process
    model = BertForPreTraining.from_pretrained('bert-base-uncased')
    model = accelerator.prepare(model)
    sample(model, accelerator)

# Launch the demo on 3 GPUs from a notebook.
notebook_launcher(demo, num_processes=3)

When I try to save after training for a few steps, it sometimes produces the following error. My assumption for the different behavior is that training desynchronizes the processes to a larger extent:

ProcessRaisedException                    Traceback (most recent call last)
<ipython-input-8-78beb23515b9> in <module>
      1 # Bert base / large model
      2 FLAGS.bert_model = 'LARGE'
----> 3 hyperparameter_search(FLAGS,hyper_params_list, 'V1_LARGE', metric='EVAL_MLM_A')

<ipython-input-6-74c6cf480cc7> in hyperparameter_search(base_flags, hyper_params, version, metric)
     22         base_flags.eval_point = item[3]
     23         base_flags.output_dir = '../ckpts00_large/'+version+'hypsearch'+'/'+model_name
---> 24         notebook_launcher(training_function, [base_flags], num_processes=base_flags.num_gpu)
     25 
     26 

/opt/conda/lib/python3.8/site-packages/accelerate/launchers.py in notebook_launcher(function, args, num_processes, use_fp16, mixed_precision, use_port)
    126 
    127                 print(f"Launching training on {num_processes} GPUs.")
--> 128                 start_processes(launcher, args=args, nprocs=num_processes, start_method="fork")
    129 
    130         else:

/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    186 
    187     # Loop on join until it returns True or raises an exception.
--> 188     while not context.join():
    189         pass
    190 

/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    148         msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
    149         msg += original_trace
--> 150         raise ProcessRaisedException(msg, error_index, failed_process.pid)
    151 
    152 

ProcessRaisedException: 

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/opt/conda/lib/python3.8/site-packages/accelerate/utils.py", line 638, in __call__
    self.launcher(*args)
  File "/home/jovyan/vbert2021/notebooks/bert_pretraining.py", line 376, in training_function
    log_list = train(FLAGS, model, data, optim, scheduler, accelerator,logger, eval_data)
  File "/home/jovyan/vbert2021/notebooks/bert_pretraining.py", line 208, in train
    unwrapped_model.save_pretrained(os.path.join(FLAGS.output_dir,'Steps_'+str(step_count-FLAGS.num_warmup_steps)),save_function = accelerator.save)
  File "/opt/conda/lib/python3.8/site-packages/transformers/modeling_utils.py", line 1371, in save_pretrained
    os.remove(full_filename)
FileNotFoundError: [Errno 2] No such file or directory: '../ckpts00_large/V1_LARGEhypsearch/V1_LARGE_1e-05_258/Steps_10/pytorch_model.bin'

In 0.6.2, I believe save() runs in every process, so multiple saves are attempted. While one process is clearing a previous save, it deletes the pytorch_model.bin file that another process expects to find.
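The FileNotFoundError in the traceback is consistent with a list-then-remove race between processes. Below is a minimal sketch of that race, with no Accelerate or Transformers involved. It assumes the cleanup step lists existing weight files and then removes them; `find_stale_weights` and `simulate_cleanup_race` are hypothetical names for illustration, and the two "processes" are interleaved by hand so the outcome is deterministic.

```python
import os
import tempfile

def find_stale_weights(save_dir):
    """Return paths of existing weight files that a cleanup pass would remove."""
    return [
        os.path.join(save_dir, name)
        for name in os.listdir(save_dir)
        if name.startswith("pytorch_model")
    ]

def simulate_cleanup_race():
    """Interleave two cleanup passes by hand and report the resulting error name."""
    with tempfile.TemporaryDirectory() as save_dir:
        weights = os.path.join(save_dir, "pytorch_model.bin")
        with open(weights, "wb") as f:
            f.write(b"previous checkpoint")

        # Both "processes" list the directory before either one deletes anything.
        stale_seen_by_a = find_stale_weights(save_dir)
        stale_seen_by_b = find_stale_weights(save_dir)

        # Process A removes the file first and succeeds.
        for path in stale_seen_by_a:
            os.remove(path)

        # Process B now acts on an outdated listing and fails.
        try:
            for path in stale_seen_by_b:
                os.remove(path)
        except FileNotFoundError as err:
            return type(err).__name__
    return None

outcome = simulate_cleanup_race()
print(outcome)
```

In a real multi-GPU run the interleaving depends on timing, which would explain why the error only shows up sometimes.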


https://huggingface.co/docs/accelerate/v0.6.0/en/accelerator#accelerate.Accelerator.save states that it saves once per machine and not once per process. Maybe my understanding of it is wrong.
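Even if the save function itself writes only once per machine, the bin file can still go missing when the surrounding save routine runs in every process. The sketch below illustrates this under two assumptions that are not confirmed here: that the save routine clears old weight files before calling the save function, and that the save function writes only on the local main process. `gated_save` and `save_checkpoint` are hypothetical stand-ins, not the library implementations.

```python
import os
import tempfile

WEIGHTS_NAME = "pytorch_model.bin"

def gated_save(state, path, local_rank):
    """Stand-in for a save that writes only on the local main process."""
    if local_rank == 0:
        with open(path, "wb") as f:
            f.write(state)

def save_checkpoint(save_dir, state, local_rank):
    """Stand-in for a save routine that clears old weights, then saves."""
    os.makedirs(save_dir, exist_ok=True)
    # The cleanup runs in every process, regardless of which one writes afterwards.
    for name in os.listdir(save_dir):
        if name.startswith("pytorch_model"):
            os.remove(os.path.join(save_dir, name))
    # The config is written by every process.
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        f.write("{}")
    gated_save(state, os.path.join(save_dir, WEIGHTS_NAME), local_rank)

with tempfile.TemporaryDirectory() as save_dir:
    # Rank 0 finishes first; ranks 1 and 2 arrive later and clean up its output.
    for rank in (0, 1, 2):
        save_checkpoint(save_dir, b"weights", rank)
    remaining = sorted(os.listdir(save_dir))

print(remaining)
```

With this ordering only the config file survives, which matches the symptom described at the top of the issue. If rank 0 happened to run last, the weights would be present.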

My current workaround is:

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained('ckpts', save_function=accelerator.save)
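The workaround can be wrapped in a small helper so every save site is guarded the same way. This is only a sketch: `save_on_main_process` is a hypothetical name, and the second `wait_for_everyone()` call is an addition not in the original workaround, intended to keep the other processes from moving on before the checkpoint exists.

```python
def save_on_main_process(accelerator, model, output_dir):
    """Save a checkpoint from the main process only.

    Returns True in the process that performed the save, False elsewhere.
    """
    # Make sure every process has finished its current step before saving.
    accelerator.wait_for_everyone()
    saved = False
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
        saved = True
    # Assumption: hold the other processes until the checkpoint is written.
    accelerator.wait_for_everyone()
    return saved
```

It would be called from the training loop in every process, for example `save_on_main_process(accelerator, model, 'ckpts')`, since the barrier calls must be reached by all of them.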

Thank you
