- 
                Notifications
    You must be signed in to change notification settings 
- Fork 221
add blip2 #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add blip2 #4
Changes from 1 commit
a90aa49
              43b703d
              1471ba3
              802d260
              29729e3
              fca6f49
              f5b9835
              9b79d2f
              23d4c4e
              38ecdbd
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|  | ||
| import argparse | ||
| import os | ||
|  | ||
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | ||
| os.environ["FLAGS_use_cuda_managed_memory"] = "true" | ||
|  | ||
| import paddle | ||
| import torch | ||
|  | ||
| from paddlenlp.transformers import LlamaForCausalLM | ||
|  | ||
|  | ||
def merge(args):
    """Assemble a MiniGPT4 checkpoint from three sources and save it.

    The merged Paddle state dict combines:
      1. ViT, Q-Former and ``query_tokens`` weights taken from a
         blip2-flan-t5-xxl Paddle checkpoint (``args.blip2_path``).
      2. The full vicuna (llama) language model loaded from
         ``args.vicuna_path``.
      3. The language projection layer plus the fine-tuned llama weights
         from the original MiniGPT4 PyTorch checkpoint
         (``args.minigpt4_path``).

    The result is written to ``<args.save_path>/model_state.pdparams``.

    Args:
        args: parsed CLI namespace with ``blip2_path`` (file),
            ``vicuna_path`` (dir), ``minigpt4_path`` (file) and
            ``save_path`` (dir) attributes.
    """
    model_dict = {}

    # [1/3] ViT, Q-Former and query_tokens come from the blip2 checkpoint.
    state_dict = paddle.load(args.blip2_path)
    for name, param in state_dict.items():
        if name.startswith(("vision_model", "qformer")) or name == "query_tokens":
            model_dict[name] = param
    print("[1/3] load ViT, qformer and query_tokens from blip2-flan-t5-xxl done!")

    # [2/3] Full language-model weights come from vicuna.
    llama_model = LlamaForCausalLM.from_pretrained(args.vicuna_path)
    for name, param in llama_model.named_parameters():
        model_dict["language_model." + name] = param
    print("[2/3] load vicuna(llama type) done!")

    # [3/3] Projection layer and fine-tuned llama weights from MiniGPT4.
    # map_location="cpu" lets the torch checkpoint load even on a machine
    # without a CUDA device.
    minigpt4_state_dict = torch.load(args.minigpt4_path, map_location="cpu")
    for name, param in minigpt4_state_dict["model"].items():
        # The two prefixes are mutually exclusive, so use if/elif.
        if name.startswith("llama_model.model"):
            new_name = name.replace("llama_model.model", "language_model.llama")
            model_dict[new_name] = paddle.to_tensor(param.cpu().numpy())
        elif name.startswith("llama_proj"):
            new_name = name.replace("llama_proj", "language_projection")
            new_param = paddle.to_tensor(param.cpu().numpy())
            if name.endswith("weight"):
                # torch.nn.Linear stores (out_features, in_features);
                # paddle expects the transposed layout.
                new_param = new_param.transpose([1, 0])
            model_dict[new_name] = new_param
    print("[3/3] load language_projection, some llama weights from minigpt4 done!")

    save_path = os.path.join(args.save_path, "model_state.pdparams")
    paddle.save(model_dict, save_path)
    print("The checkpoint of minigpt4 has been saved to :{}".format(save_path))
|  | ||
|  | ||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Merge blip2, vicuna and minigpt4 weights into a single Paddle checkpoint."
    )
    # Paths are required — no machine-specific absolute-path defaults.
    parser.add_argument("--blip2_path", type=str, required=True, help="The dir name of blip2-flan-t5-xxl.")
    parser.add_argument("--vicuna_path", type=str, required=True, help="The dir name of vicuna.")
    parser.add_argument(
        "--minigpt4_path", type=str, required=True, help="The checkpoint path of minigpt4."
    )
    parser.add_argument("--save_path", type=str, required=True, help="The saving path of minigpt4.")
    args = parser.parse_args()

    # blip2_path points at a directory; the actual weights file lives inside it.
    args.blip2_path = os.path.join(args.blip2_path, "model_state.pdparams")
    if not os.path.exists(args.blip2_path):
        raise ValueError("Not found the file: {}".format(args.blip2_path))
    if not os.path.isdir(args.vicuna_path):
        raise ValueError("It is not a directory: {}".format(args.vicuna_path))
    if not os.path.exists(args.minigpt4_path):
        raise ValueError("Not found the file: {}".format(args.minigpt4_path))
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    merge(args)
| There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 绝对路径要去掉 | 
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -13,7 +13,8 @@ | |
| # limitations under the License. | ||
|  | ||
| from dataclasses import dataclass, field | ||
|  | ||
| import sys | ||
| sys.path.insert(0,"/paddle/workspace/wjm/merge_blip2/PaddleMIX-master-597ba145fd4c2d9bfb9e2a9fd40401af15fb537e") | ||
| import paddle | ||
| import requests | ||
| from paddlenlp.trainer import PdArgumentParser | ||
|  | @@ -23,6 +24,7 @@ | |
| from paddlevlp.processors.blip_processing import Blip2Processor | ||
| from paddlevlp.utils.log import logger | ||
|  | ||
| from paddlevlp.examples.blip2.Logger import MetricLogger, SmoothedValue | ||
|          | ||
|  | ||
| @dataclass | ||
| class DataArguments: | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以用logger打印