22 changes: 17 additions & 5 deletions paddlevlp/datasets/caption_dataset.py
@@ -41,7 +41,7 @@ class CaptionDataset(DatasetBuilder):
SPLITS = {
"train": META_INFO(
os.path.join("coco", "images"),
os.path.join("coco", "annotations/coco_karpathy_train_debug.json"),
os.path.join("coco", "annotations/coco_karpathy_train.json"),
"",
"aa31ac474cf6250ebb81d18348a07ed8",
),
@@ -83,16 +83,28 @@ def _gen_image_id(self, anno):
img_ids[img_id] = n
n += 1
return img_ids

def _gen_image_id_eval(self, anno):
img_ids = {}
n = 0
for ann in anno:
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
if img_id not in img_ids.keys():
img_ids[img_id] = n
n += 1
return img_ids
def _read(self, filename, *args):
image_root, anno_path, mode = filename
annotations = json.load(open(anno_path, "r"))
image_ids = self._gen_image_id(annotations)

if mode == "val" or mode == "test":
image_ids = self._gen_image_id_eval(annotations)
else:
image_ids = self._gen_image_id(annotations)
for ann in annotations:
image_path = os.path.join(image_root, ann["image"])
yield_data = {"image": image_path, "image_id": image_ids[ann["image_id"]]}
if mode == "train":
yield_data = {"image": image_path, "image_id": image_ids[ann["image_id"]]}
# only train mode has text input
yield_data["text_input"] = ann["caption"]
else:
yield_data = {"image": image_path, "image_id": ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]}
yield yield_data
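
A note on the image-id parsing used in _gen_image_id_eval and in the eval branch above: str.strip(".jpg") removes any of the characters '.', 'j', 'p' and 'g' from both ends of the string, not the literal ".jpg" suffix. It happens to work for the numeric COCO ids, but a suffix-safe sketch (not part of this PR) would be:

# str.strip(".jpg") strips a character set, e.g. "grill.jpg".strip(".jpg") == "rill";
# os.path.splitext removes only the extension and avoids that pitfall
import os

def parse_coco_image_id(image_name):
    # "val2014/COCO_val2014_000000391895.jpg" -> "000000391895"
    return os.path.splitext(os.path.basename(image_name))[0].split("_")[-1]
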
88 changes: 88 additions & 0 deletions paddlevlp/examples/blip2/merge_weight.py
@@ -0,0 +1,88 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FLAGS_use_cuda_managed_memory"] = "true"

import paddle
import torch

from paddlenlp.transformers import LlamaForCausalLM


def merge(args):
model_dict = {}
# load the first item: blip2-flan-t5-xxl
state_dict = paddle.load(args.blip2_path)
for n, p in state_dict.items():
if n.startswith("vision_model") or n.startswith("qformer") or n == "query_tokens":
model_dict[n] = p
print("[1/3] load ViT, qformer and query_tokens from blip2-flan-t5-xxl done!")
Collaborator: You could print this with the logger.
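
A minimal sketch of the reviewer's suggestion, assuming the shared logger in paddlevlp/utils/log.py (imported elsewhere in this PR) exposes the usual info method:

# replace the bare print calls in merge() with the project-wide logger
from paddlevlp.utils.log import logger

logger.info("[1/3] load ViT, qformer and query_tokens from blip2-flan-t5-xxl done!")
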
# load the second item: vicuna
llama_model = LlamaForCausalLM.from_pretrained(args.vicuna_path)

for n, p in llama_model.named_parameters():
new_name = "language_model." + n
model_dict[new_name] = p
print("[2/3] load vicuna(llama typel) done!")

# load the third item: minigpt4
minigpt4_state_dict = torch.load(args.minigpt4_path)
for n, p in minigpt4_state_dict["model"].items():
if n.startswith("llama_model.model"):
new_name = n.replace("llama_model.model", "language_model.llama")
new_p = paddle.to_tensor(p.cpu().numpy())
model_dict[new_name] = new_p

if n.startswith("llama_proj"):
new_name = n.replace("llama_proj", "language_projection")
if n.endswith("weight"):
new_p = paddle.to_tensor(p.cpu().numpy()).transpose([1, 0])
else:
new_p = paddle.to_tensor(p.cpu().numpy())
model_dict[new_name] = new_p

print("[3/3] load language_projection, some llama weights from minigpt4 done!")

save_path = os.path.join(args.save_path, "model_state.pdparams")
paddle.save(model_dict, save_path)
print("The checkpoint of minigpt4 has been saved to :{}".format(save_path))


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--blip2_path", default="/blip2/dirname", type=str, help="The dir name of blip2-flan-t5-xxl.")
parser.add_argument("--vicuna_path", default="/vicuna/dirname", type=str, help="The dir name of vicuna.")
Collaborator: This is the MiniGPT-4 weight-merging script, right? You could let the user specify the LLM type.
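
One way to act on this suggestion; the --llm_type flag name and its choices are illustrative, not part of the PR:

# hypothetical flag letting the user pick which LLM the MiniGPT-4 weights are merged with
parser.add_argument(
    "--llm_type", default="vicuna", choices=["vicuna", "llama"], type=str,
    help="Type of the language model to merge with the BLIP-2 and MiniGPT-4 checkpoints."
)
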

parser.add_argument(
"--minigpt4_path", default="/minigpt4/prerained_minigpt4.pth", type=str, help="The checkpoint path of vicuna."
)
parser.add_argument("--save_path", default="/save/to/dirname", type=str, help="The saving path of minigpt4.")
args = parser.parse_args()

args.blip2_path = os.path.join(args.blip2_path, "model_state.pdparams")
if not os.path.exists(args.blip2_path):
raise ValueError("Not found the file: {}".format(args.blip2_path))
if not os.path.isdir(args.vicuna_path):
raise ValueError("It is not a directory: {}".format(args.vicuna_path))
if not os.path.exists(args.minigpt4_path):
raise ValueError("Not found the file: {}".format(args.minigpt4_path))
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)

merge(args)
4 changes: 3 additions & 1 deletion paddlevlp/examples/blip2/run_predict.py
Collaborator: The absolute path should be removed.
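
A sketch of the suggested cleanup, assuming paddlevlp is installed into the environment (for example via pip install -e . from the repository root) so the hard-coded sys.path entry below becomes unnecessary:

# with the package installed, the plain imports resolve without any path manipulation
from dataclasses import dataclass, field

import paddle
import requests
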

@@ -13,7 +13,8 @@
# limitations under the License.

from dataclasses import dataclass, field

import sys
sys.path.insert(0,"/paddle/workspace/wjm/merge_blip2/PaddleMIX-master-597ba145fd4c2d9bfb9e2a9fd40401af15fb537e")
import paddle
import requests
from paddlenlp.trainer import PdArgumentParser
@@ -23,6 +24,7 @@
from paddlevlp.processors.blip_processing import Blip2Processor
from paddlevlp.utils.log import logger

from paddlevlp.examples.blip2.Logger import MetricLogger, SmoothedValue
Collaborator: The logger is a shared component; it should be unified under utils/log.py.
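
A possible direction for this comment; the relocated import below is an assumption, not an existing module layout:

# keep logging utilities in one place alongside the shared logger
from paddlevlp.utils.log import logger
# hypothetical path after moving MetricLogger / SmoothedValue out of examples/blip2/Logger.py
from paddlevlp.utils.log import MetricLogger, SmoothedValue
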


@dataclass
class DataArguments: