22 changes: 17 additions & 5 deletions paddlevlp/datasets/caption_dataset.py
@@ -41,7 +41,7 @@ class CaptionDataset(DatasetBuilder):
SPLITS = {
"train": META_INFO(
os.path.join("coco", "images"),
os.path.join("coco", "annotations/coco_karpathy_train_debug.json"),
os.path.join("coco", "annotations/coco_karpathy_train.json"),
"",
"aa31ac474cf6250ebb81d18348a07ed8",
),
@@ -83,16 +83,28 @@ def _gen_image_id(self, anno):
img_ids[img_id] = n
n += 1
return img_ids

def _gen_image_id_eval(self, anno):
img_ids = {}
n = 0
for ann in anno:
img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]
if img_id not in img_ids.keys():
img_ids[img_id] = n
n += 1
return img_ids

def _read(self, filename, *args):
image_root, anno_path, mode = filename
annotations = json.load(open(anno_path, "r"))

if mode== "val" or mode=="test":
image_ids = self._gen_image_id_eval(annotations)
else:
image_ids = self._gen_image_id(annotations)
for ann in annotations:
image_path = os.path.join(image_root, ann["image"])
yield_data = {"image": image_path, "image_id": image_ids[ann["image_id"]]}
if mode == "train":
yield_data = {"image": image_path, "image_id": image_ids[ann["image_id"]]}
# only train mode has text input
yield_data["text_input"] = ann["caption"]
else:
yield_data = {"image": image_path, "image_id": ann["image"].split("/")[-1].strip(".jpg").split("_")[-1]}
yield yield_data
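
For reference, a minimal sketch of how the eval-mode image id is derived from a COCO-style filename (the path below is an illustrative example, not taken from this PR):

import os

# "coco/val2014/COCO_val2014_000000391895.jpg" -> "000000391895"
path = "coco/val2014/COCO_val2014_000000391895.jpg"
img_id = os.path.splitext(path.split("/")[-1])[0].split("_")[-1]
print(img_id)  # 000000391895
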
126 changes: 126 additions & 0 deletions paddlevlp/examples/blip2/export.py
@@ -0,0 +1,126 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field

import paddle
import requests
import yaml
from paddlenlp.trainer import PdArgumentParser
from PIL import Image

from paddlevlp.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlevlp.processors.blip_processing import Blip2Processor
from paddlevlp.utils.log import logger

@dataclass
class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `PdArgumentParser` we can turn this class
into argparse arguments to be able to specify them on
the command line.
"""

input_image: str = field(
    default="http://images.cocodataset.org/val2017/000000039769.jpg",
    metadata={"help": "The URL or path of the input image."},
)
prompt: str = field(
default=None, metadata={"help": "The prompt of the image to be generated."}
) # "Question: how many cats are there? Answer:"


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""

model_name_or_path: str = field(
default="Salesforce/blip2-opt-2.7b",
metadata={"help": "Path to pretrained model or model identifier"},
)
pretrained_model_path: str = field(
    default=None,
    metadata={"help": "The path to the pretrained model to use for inference."},
)
fp16: bool = field(
    default=True,
    metadata={"help": "Whether to export with mixed precision."},
)


def main():
parser = PdArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()
url = data_args.input_image  # e.g. "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "a photo of "
processor = Blip2Processor.from_pretrained(
model_args.model_name_or_path
) # "Salesforce/blip2-opt-2.7b"
model = Blip2ForConditionalGeneration.from_pretrained(model_args.model_name_or_path)
model.eval()
dtype="float32"
if model_args.fp16:
decorated = paddle.amp.decorate(
models=[model.visual_encoder,model.language_model], optimizers=None, level="O2"
)
model.visual_encoder,model.language_model= decorated
dtype="float16"

shape1 = [None, 3, None, None]
input_spec = [
    paddle.static.InputSpec(shape=shape1, dtype="float32"),
]
image_encoder = paddle.jit.to_static(model.encode_image, input_spec=input_spec)
save_path = "blip2_export"
paddle.jit.save(image_encoder, os.path.join(save_path, 'image_encoder'))

# TODO add test config
deploy_info = {
    "Deploy": {
        "model": "image_encoder.pdmodel",
        "params": "image_encoder.pdiparams",
        "input_img_shape": shape1,
        "output_dtype": dtype,
    }
}
msg = '\n---------------Deploy Information---------------\n'
msg += yaml.dump(deploy_info)
logger.info(msg)

yml_file = os.path.join(save_path, 'deploy.yaml')
with open(yml_file, 'w') as file:
yaml.dump(deploy_info, file)

logger.info(f'The inference model is saved in {save_path}')

if __name__ == "__main__":
main()
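
To sanity-check the exported static graph, a minimal sketch of loading it back (assuming the default blip2_export output directory; the 224x224 input is an illustrative assumption, any [N, 3, H, W] float32 tensor matches the InputSpec above):

import paddle

# Load the image encoder saved by paddle.jit.save above.
encoder = paddle.jit.load("blip2_export/image_encoder")
# Hypothetical input; the InputSpec accepts any [N, 3, H, W] float32 tensor.
fake_image = paddle.randn([1, 3, 224, 224], dtype="float32")
features = encoder(fake_image)
print(features.shape)
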
95 changes: 95 additions & 0 deletions paddlevlp/examples/blip2/merge_weight.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from paddlevlp.utils.log import logger
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["FLAGS_use_cuda_managed_memory"] = "true"

import paddle
import torch

from paddlenlp.transformers import LlamaForCausalLM
from paddlenlp.transformers.opt.modeling import OPTForCausalLM

def merge(args):
model_dict = {}
# load the first item: vision_model
state_dict = paddle.load(args.blip2_path)
for n, p in state_dict.items():
if n.startswith("vision_model") or n.startswith("qformer") or n == "query_tokens":
model_dict[n] = p
logger.info("[1/3] load ViT, qformer and query_tokens done!")

# load the second item: llm model
if "opt" in args.llm_name:
llm_model = OPTForCausalLM.from_pretrained(args.llm_path)
elif: "llama" in args.llm_name:
llm_model = LlamaForCausalLM.from_pretrained(args.llm_path)
else:
ValueError(f"The LLM model {args.llm_name} is not supported.")


for n, p in llm_model.named_parameters():
new_name = "language_model." + n
model_dict[new_name] = p
logger.info("[2/3] load language_model done!")

# load the third item: language_projection and some llm weights
# from the original torch blip2 checkpoint
llm_state_dict = torch.load(args.blip2_torch_path)
for n, p in llm_state_dict["model"].items():
    if n.startswith(args.llm_name + "_model.model"):
        new_name = n.replace(args.llm_name + "_model.model", "language_model." + args.llm_name)
        new_p = paddle.to_tensor(p.cpu().numpy())
        model_dict[new_name] = new_p

    if n.startswith(args.llm_name + "_proj"):
        new_name = n.replace(args.llm_name + "_proj", "language_projection")
        if n.endswith("weight"):
            # linear weights are transposed between torch and paddle layouts
            new_p = paddle.to_tensor(p.cpu().numpy()).transpose([1, 0])
        else:
            new_p = paddle.to_tensor(p.cpu().numpy())
        model_dict[new_name] = new_p

logger.info("[3/3] load language_projection, some llm weights from blip2 done!")

save_path = os.path.join(args.save_path, "model_state.pdparams")
paddle.save(model_dict, save_path)
logger.info("The checkpoint of blip2 has been saved to :{}".format(save_path))


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--blip2_path", default="/blip2/dirname", type=str, help="The dir name of blip2-flan-t5-xxl.")
parser.add_argument("--llm_name", default="opt", type=str, help="Thename of llm model.")
parser.add_argument("--llm_path", default="/llm/dirname", type=str, help="The dir name of llm model.")
parser.add_argument(
"--blip2_path", default="/blip2/prerained_blip2.pth", type=str, help="The checkpoint path of blip2."
)
parser.add_argument("--save_path", default="/save/to/dirname", type=str, help="The saving path of blip2.")
args = parser.parse_args()

args.blip2_path = os.path.join(args.blip2_path, "model_state.pdparams")
if not os.path.exists(args.blip2_path):
    raise ValueError("File not found: {}".format(args.blip2_path))
if not os.path.exists(args.llm_path):
    raise ValueError("Directory not found: {}".format(args.llm_path))
if not os.path.isdir(args.llm_path):
    raise ValueError("Not a directory: {}".format(args.llm_path))
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)

merge(args)
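
To verify the merged checkpoint, a minimal sketch of loading it into the model (the save path and model identifier follow the defaults used above and are illustrative):

import paddle
from paddlevlp.models.blip2.modeling import Blip2ForConditionalGeneration

# Load the merged state dict written by merge() and apply it in place.
state_dict = paddle.load("/save/to/dirname/model_state.pdparams")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model.set_state_dict(state_dict)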