-
Notifications
You must be signed in to change notification settings - Fork 220
add blip2 #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add blip2 #4
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
绝对路径要去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
绝对路径去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经删掉了
| for n, p in state_dict.items(): | ||
| if n.startswith("vision_model") or n.startswith("qformer") or n == "query_tokens": | ||
| model_dict[n] = p | ||
| print("[1/3] load ViT, qformer and query_tokens from blip2-flan-t5-xxl done!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以用logger打印
| parser = argparse.ArgumentParser() | ||
|
|
||
| parser.add_argument("--blip2_path", default="/blip2/dirname", type=str, help="The dir name of blip2-flan-t5-xxl.") | ||
| parser.add_argument("--vicuna_path", default="/vicuna/dirname", type=str, help="The dir name of vicuna.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是minigpt4的merge权重脚本吧,可以让用户指定llm类型
| from paddlevlp.processors.blip_processing import Blip2Processor | ||
| from paddlevlp.utils.log import logger | ||
|
|
||
| from paddlevlp.examples.blip2.Logger import MetricLogger, SmoothedValue |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logger是公共部分,统一在utils/log.py下
| processor = Blip2Processor.from_pretrained(model_args.model_name_or_path) | ||
| blip_collator = BlipCollator(processor) | ||
|
|
||
| image_processor = BlipImageProcessor(image_mean=[0.48145466, 0.4578275, 0.40821073], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数通过processor自动下载的配置文件获取,训练推理可以在processor内做区分
| optimizers=(optimizer, lr_sched), | ||
| processor=processor, | ||
| eval_processor=eval_processor, | ||
| tokenizer=AutoTokenizer.from_pretrained("facebook/opt-2.7b", use_fast=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tokenizer已经在collator中实现,是否还需要
paddlevlp/models/blip2/modeling.py
Outdated
| image_attention_mask = paddle.ones(image_embeds.shape[:-1], dtype="int64") | ||
|
|
||
| query_tokens = self.query_tokens.expand([image_embeds.shape[0], -1, -1]) | ||
| # print('DEBUG!! Blip2ForCond query_tokens: ', query_tokens.shape, np.abs(query_tokens.numpy()).mean()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除debug代码
paddlevlp/trainer/blip2_trainer.py
Outdated
| __all__ = ["BLIP2_Trainer"] | ||
|
|
||
|
|
||
| class BLIP2_Trainer(Trainer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不带_, BLIP2Trainer
paddlevlp/trainer/blip2_trainer.py
Outdated
|
|
||
| class BLIP2_Trainer(Trainer): | ||
| """ | ||
| BLIP2_Trainer is a simple but feature-complete training and eval loop for PaddlePaddle, optimized for PaddleNLP. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释说明下和Trainer的差异
remove if in PatchEmbed(
## 算子目录 - [1. 转换算子](#1-转换算子) - [1.1 llava转换算子](#11-llava转换算子) - [1.1.1 llava_convert](#111-llava_convert) - [2. 过滤算子](#2-过滤算子) - [2.1 基础过滤算子](#21-基础过滤算子) - [2.1.1 valid_data_filter](#211-valid_data_filter) - [2.1.1.1 image_compliance_operator](#2111-image_compliance_operator) - [2.1.1.2 conversation_compliance_operator](#2112-conversation_compliance_operator) - [2.2 文本过滤算子](#22-文本过滤算子) - [2.2.1 conversation_length_filter](#221-conversation_length_filter) - [2.2.2 average_line_length_filter](#222-average_line_length_filter) - [2.2.3 maximum_line_length_filter](#223-maximum_line_length_filter) - [2.2.4 conversation_percentage_filter](#224-conversation_percentage_filter) - [2.2.5 token_num_filter](#225-token_num_filter) - [2.2.6 alphanumeric_ratio_filter](#226-alphanumeric_ratio_filter) - [2.2.7 stopwords_ratio_filter](#227-stopwords_ratio_filter) - [2.2.8 special_characters_filter](#228-special_characters_filter) - [2.2.9 language_id_filter](#229-language_id_filter) - [2.2.10 text_action_filter](#2210-text_action_filter) - [2.2.11 text_entity_dependency_filter](#2211-text_entity_dependency_filter) - [2.2.12 char_ngram_repetition_filter](#2212-char_ngram_repetition_filter) - [2.2.13 word_ngram_repetition_filter](#2213-word_ngram_repetition_filter) - [2.2.14 conversation_hash_filter](#2214-conversation_hash_filter) - [2.2.14.1 simhash_duplicate_operator](#22141-simhash_duplicate_operator) - [2.2.14.2 minhash_duplicate_operator](#22142-minhash_duplicate_operator) - [2.2.15 llm_judge_filter](#2215-llm_judge_filter) - [2.3 图像过滤算子](#23-图像过滤算子) - [2.3.1 image_filesize_filter](#231-image_filesize_filter) - [2.3.2 image_ration_filter](#232-image_ration_filter) - [2.3.3 image_resolution_filter](#233-image_resolution_filter) - [2.3.4 image_hash_filter](#234-image_hash_filter) - [2.4 图文过滤算子](#24-图文过滤算子) - [2.4.1 image_clip_filter](#241-image_clip_filter) - [3. 分析算子](#3-分析算子) - [3.1 基础分析算子](#31-基础分析算子) - [3.1.1 base_analysis_pipeline](#311-base_analysis_pipeline) - [3.1.1.1 analyze_dataset_statistics](#3111-analyze_dataset_statistics) - [3.1.1.2 analyze_language_distribution](#3112-analyze_language_distribution) - [3.1.1.3 analyze_image_paths](#3113-analyze_image_paths) - [3.1.1.4 analyze_data_anomalies](#3114-analyze_data_anomalies) - [3.1.1.5 analyze_conversation_tokens](#3115-analyze_conversation_tokens) - [3.2 进阶分析算子](#32-进阶分析算子) - [3.2.1 description_analysis](#321-description_analysis) - [3.2.2 quality_analysis](#322-quality_analysis) - [4. 可视化算子](#4-可视化算子) - [4.1 lda可视化算子](#41-lda可视化算子) - [4.1.1 lda_topic_clustering](#411-lda_topic_clustering) - [5. 生成算子](#5-生成算子) - [5.1 多模态生成算子](#51-多模态生成算子) - [5.1.1 generate_qna_for_images](#511-generate_qna_for_images) --- - #1055
No description provided.