support paligemma #1004

Merged: merged 6 commits into from May 29, 2024

3 changes: 2 additions & 1 deletion README.md
@@ -535,7 +535,8 @@ The complete list of supported models and datasets can be found at [Supported Mo
| mPLUG-Owl | [mPLUG-Owl series models](https://github.com/X-PLUG/mPLUG-Owl) | English | 11B | chat model |
| InternVL | [InternVL](https://github.com/OpenGVLab/InternVL) | Chinese<br>English | 25.5B<br>including quantized version | chat model |
| Llava-llama3 | [xtuner](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) | English | 8B | chat model |
| Phi3 | Microsoft | English | 4B | chat model |
| Phi3-Vision | Microsoft | English | 4B | chat model |
| PaliGemma | Google | English | 3B | chat model |

#### Diffusion Models

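For a quick sanity check of the new table entry, here is a minimal sketch of running one of the PaliGemma checkpoints through the plain transformers API (the same `PaliGemmaForConditionalGeneration` and `AutoProcessor` classes this PR uses; requires transformers>=4.41, and `cat.png` plus the `caption en` task prompt are placeholder inputs):

```python
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = 'google/paligemma-3b-mix-224'  # any variant from the table above
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(text='caption en', images=Image.open('cat.png'), return_tensors='pt')
output = model.generate(**inputs, max_new_tokens=32)
# Decode only the newly generated tokens, skipping the prompt.
print(processor.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True))
```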
3 changes: 2 additions & 1 deletion README_CN.md
@@ -532,7 +532,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
| mPLUG-Owl | [mPLUG-Owl series models](https://github.com/X-PLUG/mPLUG-Owl) | English | 11B | chat model |
| InternVL | [InternVL](https://github.com/OpenGVLab/InternVL) | Chinese<br>English | 25.5B<br>including quantized version | chat model |
| Llava-llama3 | [xtuner](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) | English | 8B | chat model |
| Phi3 | Microsoft | English | 4B | chat model |
| Phi3-Vision | Microsoft | English | 4B | chat model |
| PaliGemma | Google | English | 3B | chat model |

#### Diffusion Models

5 changes: 5 additions & 0 deletions docs/source/LLM/支持的模型和数据集.md
@@ -200,6 +200,11 @@
|gemma-7b|[AI-ModelScope/gemma-7b](https://modelscope.cn/models/AI-ModelScope/gemma-7b/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|
|gemma-2b-instruct|[AI-ModelScope/gemma-2b-it](https://modelscope.cn/models/AI-ModelScope/gemma-2b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it)|
|gemma-7b-instruct|[AI-ModelScope/gemma-7b-it](https://modelscope.cn/models/AI-ModelScope/gemma-7b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)|
|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|multi-modal, vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
|paligemma-3b-mix-224|[AI-ModelScope/paligemma-3b-mix-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)|
|paligemma-3b-mix-448|[AI-ModelScope/paligemma-3b-mix-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448)|
|minicpm-1b-sft-chat|[OpenBMB/MiniCPM-1B-sft-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-1B-sft-bf16/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;|transformers>=4.36.0|-|[openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16)|
|minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|
|minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|
2 changes: 1 addition & 1 deletion docs/source/LLM/自定义与拓展.md
@@ -8,7 +8,7 @@
We support three methods for **customizing datasets**.

1. [Recommended] Using **command line arguments**: the most convenient way to support custom datasets; supports four dataset formats (via `SmartPreprocessor`) as well as `dataset_id` and `dataset_path`.
2. Adding the dataset to `dataset_info.json`: more flexible but more cumbersome than the first method; supports applying two preprocessors to the dataset and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (`SmartPreprocessor` is used by default). You can modify swift's built-in `dataset_info.json` directly, or pass in an external json file via `--dataset_info_path xxx.json` (convenient for users who pip install rather than git clone swift to extend datasets).
2. Adding the dataset to `dataset_info.json`: more flexible but more cumbersome than the first method; supports applying two preprocessors to the dataset and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (`SmartPreprocessor` is used by default). You can modify swift's built-in `dataset_info.json` directly, or pass in an external json file via `--custom_dataset_info xxx.json` (convenient for users who pip install rather than git clone swift to extend datasets).
3. **Registering the dataset**: more flexible but more cumbersome than methods 1 and 2; supports using functions to preprocess the dataset. Methods 1 and 2 are implemented on top of method 3. You can extend by modifying the source code directly, or pass in a py file via `--custom_register_path xxx.py`, which the script will parse (convenient for pip install users).

### 📌 [Recommended] Using Command Line Arguments
2 changes: 1 addition & 1 deletion docs/source_en/LLM/Customization.md
@@ -9,7 +9,7 @@
We support three methods for **customizing datasets**.

1. \[Recommended\] Using command line arguments: the most convenient way to support custom datasets; supports four dataset formats (using `SmartPreprocessor`) as well as `dataset_id` and `dataset_path`.
2. Adding datasets to `dataset_info.json` is more flexible but cumbersome compared to the first method, and supports using two preprocessors and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (default is to use `SmartPreprocessor`). You can directly modify the built-in `dataset_info.json` in Swift, or pass in an external json file using `--dataset_info_path xxx.json` (for users who prefer pip install over git clone to expand datasets).
2. Adding datasets to `dataset_info.json` is more flexible but cumbersome compared to the first method, and supports using two preprocessors and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (default is to use `SmartPreprocessor`). You can directly modify the built-in `dataset_info.json` in Swift, or pass in an external json file using `--custom_dataset_info xxx.json` (for users who prefer pip install over git clone to expand datasets).
3. Registering datasets: More flexible but cumbersome compared to the first and second methods, it supports using functions to preprocess datasets. Methods 1 and 2 are implemented by leveraging method 3. You can directly modify the source code for expansion, or pass in a custom registration path using `--custom_register_path xxx.py`, where the script will parse the py file (for pip install users).

### 📌 \[Recommended\] using Command Line Arguments
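To make the `--custom_dataset_info` flow above concrete, an external info file could hypothetically be generated like this (the dataset id and column names are invented for illustration; the authoritative schema is whatever swift's built-in `dataset_info.json` uses):

```python
import json

# One entry mapping source columns onto swift's query/response convention
# (RenameColumnsPreprocessor-style); adjust to the real schema as needed.
custom_info = {
    'my-captions': {
        'dataset_id': 'my-org/my-captions',  # hypothetical dataset id
        'columns': {'question': 'query', 'answer': 'response'},
    }
}
with open('xxx.json', 'w') as f:
    json.dump(custom_info, f, ensure_ascii=False, indent=2)
```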
5 changes: 5 additions & 0 deletions docs/source_en/LLM/Supported-models-datasets.md
@@ -200,6 +200,11 @@ The table below introduces all models supported by SWIFT:
|gemma-7b|[AI-ModelScope/gemma-7b](https://modelscope.cn/models/AI-ModelScope/gemma-7b/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|
|gemma-2b-instruct|[AI-ModelScope/gemma-2b-it](https://modelscope.cn/models/AI-ModelScope/gemma-2b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it)|
|gemma-7b-instruct|[AI-ModelScope/gemma-7b-it](https://modelscope.cn/models/AI-ModelScope/gemma-7b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)|
|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|multi-modal, vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
|paligemma-3b-mix-224|[AI-ModelScope/paligemma-3b-mix-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)|
|paligemma-3b-mix-448|[AI-ModelScope/paligemma-3b-mix-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448)|
|minicpm-1b-sft-chat|[OpenBMB/MiniCPM-1B-sft-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-1B-sft-bf16/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;|transformers>=4.36.0|-|[openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16)|
|minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|
|minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|
79 changes: 75 additions & 4 deletions swift/llm/utils/model.py
@@ -251,6 +251,12 @@ class ModelType:
    gemma_7b = 'gemma-7b'
    gemma_2b_instruct = 'gemma-2b-instruct'
    gemma_7b_instruct = 'gemma-7b-instruct'
    # paligemma
    paligemma_3b_pt_224 = 'paligemma-3b-pt-224'
    paligemma_3b_pt_448 = 'paligemma-3b-pt-448'
    paligemma_3b_pt_896 = 'paligemma-3b-pt-896'
    paligemma_3b_mix_224 = 'paligemma-3b-mix-224'
    paligemma_3b_mix_448 = 'paligemma-3b-mix-448'
    # minicpm
    minicpm_1b_sft_chat = 'minicpm-1b-sft-chat'
    minicpm_2b_sft_chat = 'minicpm-2b-sft-chat'
@@ -532,15 +538,15 @@ def _new_forward(self, x):
    LoRATM.cogvlm,
    TemplateType.cogvlm,
    support_gradient_checkpointing=False,
    pad_token='<|reserved_special_token_0|>',
    placeholder_tokens=['<|reserved_special_token_0|>'],
    hf_model_id='THUDM/cogvlm2-llama3-chat-19B')
@register_model(
    ModelType.cogvlm2_19b_chat,
    'ZhipuAI/cogvlm2-llama3-chinese-chat-19B',
    LoRATM.cogvlm,
    TemplateType.cogvlm,
    support_gradient_checkpointing=False,
    pad_token='<|reserved_special_token_0|>',
    placeholder_tokens=['<|reserved_special_token_0|>'],
    hf_model_id='THUDM/cogvlm2-llama3-chinese-chat-19B')
@register_model(
    ModelType.atom_7b,
@@ -844,6 +850,10 @@ def get_model_tokenizer_from_repo(model_dir: str,
    pad_token = kwargs.get('pad_token')
    if pad_token is not None:
        tokenizer.pad_token = pad_token
    placeholder_tokens = kwargs.get('placeholder_tokens')
    if placeholder_tokens is not None:
        tokenizer.placeholder_tokens = placeholder_tokens
        tokenizer.placeholder_tokens_id = [tokenizer.convert_tokens_to_ids(token) for token in placeholder_tokens]
    model = None
    if load_model:
        if kwargs.get('use_unsloth', False):
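The attributes attached here give downstream code a uniform way to locate placeholder positions (such as `<image>`) in an encoded sequence. A hypothetical illustration (the query string is invented; the real consumers live elsewhere in swift):

```python
# Find where the placeholder tokens landed in a tokenized prompt.
input_ids = tokenizer('<image>describe the image', add_special_tokens=False)['input_ids']
image_positions = [i for i, token_id in enumerate(input_ids)
                   if token_id in tokenizer.placeholder_tokens_id]
```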
@@ -1083,6 +1093,65 @@ def get_model_tokenizer_baichuan_13b(model_dir: str,
    return model, tokenizer


@register_model(
    ModelType.paligemma_3b_pt_224,
    'AI-ModelScope/paligemma-3b-pt-224',
    LoRATM.llama2,
    TemplateType.paligemma,
    support_flash_attn=True,
    requires=['transformers>=4.41'],
    placeholder_tokens=['<image>'],
    hf_model_id='google/paligemma-3b-pt-224')
@register_model(
    ModelType.paligemma_3b_pt_448,
    'AI-ModelScope/paligemma-3b-pt-448',
    LoRATM.llama2,
    TemplateType.paligemma,
    support_flash_attn=True,
    requires=['transformers>=4.41'],
    placeholder_tokens=['<image>'],
    tags=['multi-modal', 'vision'],
    hf_model_id='google/paligemma-3b-pt-448')
@register_model(
    ModelType.paligemma_3b_pt_896,
    'AI-ModelScope/paligemma-3b-pt-896',
    LoRATM.llama2,
    TemplateType.paligemma,
    support_flash_attn=True,
    requires=['transformers>=4.41'],
    placeholder_tokens=['<image>'],
    hf_model_id='google/paligemma-3b-pt-896')
@register_model(
    ModelType.paligemma_3b_mix_224,
    'AI-ModelScope/paligemma-3b-mix-224',
    LoRATM.llama2,
    TemplateType.paligemma,
    support_flash_attn=True,
    requires=['transformers>=4.41'],
    placeholder_tokens=['<image>'],
    hf_model_id='google/paligemma-3b-mix-224')
@register_model(
    ModelType.paligemma_3b_mix_448,
    'AI-ModelScope/paligemma-3b-mix-448',
    LoRATM.llama2,
    TemplateType.paligemma,
    support_flash_attn=True,
    requires=['transformers>=4.41'],
    placeholder_tokens=['<image>'],
    hf_model_id='google/paligemma-3b-mix-448')
def get_model_tokenizer_paligemma_vision(model_dir: str,
                                         torch_dtype: Dtype,
                                         model_kwargs: Dict[str, Any],
                                         load_model: bool = True,
                                         **kwargs):
    from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    model, tokenizer = get_model_tokenizer_from_repo(
        model_dir, torch_dtype, model_kwargs, load_model, automodel_class=PaliGemmaForConditionalGeneration, **kwargs)
    tokenizer.processor = processor
    return model, tokenizer
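Once registered, these variants load through swift's generic entry point, which resolves the `ModelType` constant to the decorated loader above. A hedged sketch (the dtype and `device_map` values are illustrative):

```python
import torch
from swift.llm import ModelType, get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    ModelType.paligemma_3b_pt_224, torch.bfloat16, {'device_map': 'auto'})
processor = tokenizer.processor  # attached by the loader above
```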


@register_model(
    ModelType.phi3_vision_128k_instruct,
    'LLM-Research/Phi-3-vision-128k-instruct',
@@ -2678,7 +2747,6 @@ def get_model_tokenizer_deepseek2(model_dir: str,
    if model is not None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
        # fix dtype bug
        model.generation_config.pad_token_id = model.generation_config.eos_token_id
        mlp_cls = model.model.layers[1].mlp.__class__
        for module in model.modules():
            if isinstance(module, mlp_cls):
@@ -4051,7 +4119,7 @@ def _new_forward(*args, **kwargs) -> Tensor:
    TemplateType.minicpm_v_v2_5,
    support_flash_attn=True,
    requires=['timm'],
    pad_token='<unk>',
    placeholder_tokens=['<unk>'],
    function_kwargs={'patching_embedding': True},
    hf_model_id='openbmb/MiniCPM-Llama3-V-2_5')
def get_model_tokenizer_minicpm_v(model_dir: str,
@@ -4396,6 +4464,9 @@ def get_model_tokenizer(model_type: str,
    pad_token = model_info.get('pad_token')
    if pad_token is not None:
        kwargs['pad_token'] = pad_token
    placeholder_tokens = model_info.get('placeholder_tokens')
    if placeholder_tokens is not None:
        kwargs['placeholder_tokens'] = placeholder_tokens
    if 'is_training' not in kwargs:
        kwargs['is_training'] = False
    model, tokenizer = get_function(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
44 changes: 43 additions & 1 deletion swift/llm/utils/template.py
@@ -14,7 +14,7 @@

from swift.llm.agent.utils import calculate_loss_scale
from swift.torchacc_utils import pad_and_split_batch
from swift.utils import get_dist_setting, use_torchacc
from swift.utils import get_dist_setting, upper_bound, use_torchacc

DEFAULT_SYSTEM = 'You are a helpful assistant.'
History = List[Union[Tuple[str, str], List[str]]]
@@ -70,6 +70,7 @@ class TemplateType:
    minicpm_v = 'minicpm-v'
    minicpm_v_v2_5 = 'minicpm-v-v2_5'
    gemma = 'gemma'
    paligemma = 'paligemma'
    mplug_owl2 = 'mplug-owl2'
    wizardlm2_awq = 'wizardlm2-awq'
    wizardlm2 = 'wizardlm2'
@@ -1044,6 +1045,8 @@ def __init__(self):

    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        inputs, _ = super().encode(example)
        if len(inputs) == 0:
            return inputs, {}
        image_path = example['images']
        raw_image = _read_from_path(image_path[0])
        pixel_values = self.tokenizer.processor.image_processor(raw_image, return_tensors='pt')['pixel_values']
@@ -1064,6 +1067,45 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
    lazy_tokenize=True)


class PaliGemmaTemplate(Template):

    def __init__(self):
        Template.__init__(self, ['<bos>'], ['{{QUERY}}\n'], None, ['<eos>'])

    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        inputs, _ = super().encode(example)
        image_token = self.tokenizer.encode('<image>', add_special_tokens=False)
        assert len(image_token) == 1
        image_token = image_token[0]
        if len(inputs) == 0:
            return inputs, {}
        image_path = example['images']
        processor = self.tokenizer.processor
        inputs['input_ids'] = [image_token] * processor.image_seq_length + inputs['input_ids']
        if inputs['labels'] is not None:
            n = upper_bound(0, len(inputs['labels']), lambda idx: inputs['labels'][idx] == -100)
            n2 = len(inputs['labels']) - n
            inputs['labels'] = [-100] * processor.image_seq_length + inputs['labels']
            inputs['token_type_ids'] = [0] * (processor.image_seq_length + n) + [1] * n2
        else:
            inputs['token_type_ids'] = [0] * len(inputs['input_ids'])
        raw_image = _read_from_path(image_path[0])
        model_inputs = processor(text=example['query'], images=raw_image, return_tensors='pt')
        inputs['pixel_values'] = model_inputs['pixel_values']
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super().data_collator(batch, padding_to)
        res['pixel_values'] = torch.concat([b['pixel_values'] for b in batch])
        token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)
        res['token_type_ids'] = token_type_ids
        return res


register_template(TemplateType.paligemma, PaliGemmaTemplate(), infer_media_type='dialogue', lazy_tokenize=True)
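For reference, the sequence layout `encode` produces for a labeled sample looks roughly like this (a sketch; `image_seq_length` is read from the processor, e.g. 256 for the 224px checkpoints, and all concrete lengths here are assumptions):

```python
# prompt_len = number of tokens in '<bos>' + query + '\n'
# input_ids:      [<image>] * 256 + tok(prompt) + tok(response) + [<eos>]
# labels:         [-100]    * 256 + [-100] * prompt_len + tok(response) + [<eos>]
# token_type_ids: [0] * (256 + prompt_len) + [1] * (response_len + 1)
#
# upper_bound(...) counts the leading -100 prompt labels so that the 0/1
# boundary in token_type_ids falls exactly on the first response token.
```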


class Phi3VisionTemplate(Template):

    def __init__(self):