This repository was archived by the owner on Nov 1, 2024. It is now read-only.

How to load sharded checkpoints? #31

@patrickvonplaten


❓ Questions and Help

After setting up the libraries as described in https://github.com/facebookresearch/metaseq/blob/main/docs/setup.md,
it is possible to load the 350m checkpoint, since it is not sharded, as follows. First, download the checkpoint:

wget -P ./ https://dl.fbaipublicfiles.com/opt/v1_20220502/350m/reshard.pt
  1. Next, we need to comment out one line in the Megatron-LM library that is only relevant for training (it initializes different random seeds across pipeline-parallel ranks):
    Comment out this line: https://github.com/ngoyal2707/Megatron-LM/blob/ae0b844c1f6725c3433a95e42cac760b3885170b/megatron/initialize.py#L65 in your local clone of Megatron-LM.

  2. Now we write the following Python script to a run_model.py file:

import os

from transformers import GPT2Tokenizer
from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils

path = "./"

# model hyper-parameters for the 350m checkpoint, taken from: https://arxiv.org/pdf/2205.01068.pdf | table 1
initialize_megatron(args_defaults={
    "micro_batch_size": 1, 
    "num_layers": 24, 
    "hidden_size": 1024, 
    "num_attention_heads": 16,
    "max_position_embeddings": 2048, 
    "encoder_seq_length": 2048 
})

# the OPT tokenizer is a GPT-2-style BPE tokenizer; saving it writes vocab.json and merges.txt next to the checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

# load_model_ensemble_and_task returns (models, cfg, task); the tokenizer files
# saved above are passed in as overrides for the task's dictionary
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "reshard.pt")],
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)

# take the first (and only) model of the ensemble and put it in eval mode
model = checkpoint[0][0].eval()
  3. We can then load the checkpoint by running:
torchrun run_model.py --pipeline-model-parallel-size 1 --tensor-model-parallel-size 1
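
As a quick sanity check that the model actually loaded, one can append something along these lines to run_model.py. This is only a sketch: it assumes the returned metaseq model follows the fairseq-style forward(src_tokens) API that returns (logits, extra), and the prompt is arbitrary:

import torch

prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

with torch.no_grad():
    # assumption: forward(src_tokens) returns (logits, extra) as in fairseq
    logits = model(input_ids)[0]

# greedily pick the next token and decode it
next_token_id = logits[0, -1].argmax().item()
print(tokenizer.decode([next_token_id]))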

Problem: this only works for the 350m checkpoint! For the other checkpoints this doesn't work.
E.g. when replacing:
[os.path.join(path, "reshard.pt")]
with
[os.path.join(path, "reshard-model_part-0.pt"), os.path.join(path, "reshard-model_part-1.pt")] (part-0 and part-1 of the 125M model),
we get an error because the weights are all flattened into 1D arrays.
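
To make the problem concrete, loading one of the sharded files directly shows flattened parameters rather than the usual named weight matrices. A minimal inspection sketch (the key names, e.g. "model" and the flat_param entries, are assumptions about the FSDP checkpoint layout):

import torch

state = torch.load("reshard-model_part-0.pt", map_location="cpu")
for name, tensor in state["model"].items():
    print(name, tuple(tensor.shape))

# Instead of entries like decoder.layers.0.self_attn.k_proj.weight with a 2-D shape,
# this prints entries like decoder.layers.0.flat_param_0 with a single 1-D shape.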

Using #29 sadly doesn't help either, since the checkpoints don't seem to be in the *shard* format required here:

sorted(glob(f"{pth_prefix}*shard*.pt"), key=_get_shard_number)

The parameter flattening seems to come from FairScale, and we've found some functionality to unflatten it here: https://github.com/facebookresearch/fairscale/blob/51b53ddb6c3aa77426c7d5cc0b543b79628053c4/fairscale/nn/misc/flatten_params_wrapper.py#L358 , but we haven't managed to figure out exactly how to make it work.
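
In case it helps, this is the rough direction we are exploring, modelled on FairScale's FlattenParamsWrapper logic: if each shard stores, next to the flat parameters, metadata with the original parameter names, shapes and numels, then every flat tensor can be split and reshaped back into per-parameter tensors. This is only a sketch under those assumptions; all the key names below (shard_metadata, param_metadata, fsdp_path, params, names, shapes, numels) are guesses about the checkpoint layout, not verified:

import torch

def unflatten_part(path):
    # Sketch: rebuild per-parameter tensors from one flattened FSDP shard.
    # The metadata keys used here are assumptions, not a verified metaseq API.
    state = torch.load(path, map_location="cpu")
    flat_params = state["model"]
    metadata = state["shard_metadata"]["param_metadata"]  # assumed key

    unflattened = {}
    for module_meta in metadata:
        prefix = module_meta["fsdp_path"]  # assumed key, e.g. "decoder.layers.0"
        for flat_name, info in module_meta["params"].items():
            key = f"{prefix}.{flat_name}" if prefix else flat_name
            flat = flat_params[key]
            # drop any trailing padding, then split the 1-D buffer back into
            # the original per-parameter chunks and restore their shapes
            flat = flat[: sum(info["numels"])]
            chunks = torch.split(flat, info["numels"])
            for name, shape, chunk in zip(info["names"], info["shapes"], chunks):
                full_name = f"{prefix}.{name}" if prefix else name
                unflattened[full_name] = chunk.reshape(tuple(shape))
    return unflattened

Even if this mapping is right, the two resulting state dicts (part-0 and part-1) would presumably still have to be merged along the tensor-parallel dimensions per layer, which is the part we understand least.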

@stephenroller @suchenzang @zhiqwang - any pointers on how we could load the 125M model (and the others) into a model instance of metaseq?
