Commit 2492338

support paligemma (modelscope#1004)
1 parent 4201dc7 · commit 2492338

9 files changed: +136 additions, -11 deletions

README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -535,7 +535,8 @@ The complete list of supported models and datasets can be found at [Supported Mo
 | mPLUG-Owl | [mPLUG-Owl series models](https://github.com/X-PLUG/mPLUG-Owl) | English | 11B | chat model |
 | InternVL | [InternVL](https://github.com/OpenGVLab/InternVL) | Chinese<br>English | 25.5B<br>including quantized version | chat model |
 | Llava-llama3 | [xtuner](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) | English | 8B | chat model |
-| Phi3 | Microsoft | English | 4B | chat model |
+| Phi3-Vision | Microsoft | English | 4B | chat model |
+| PaliGemma | Google | English | 3B | chat model |

 #### Diffusion Models

```
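
With PaliGemma now listed among the supported chat models, the quickest way to try one of the new checkpoints is through swift's Python entry points. The snippet below is a minimal, hedged sketch: `InferArguments`/`infer_main` are assumed to be exported by `swift.llm`, and the chosen `model_type` is one of the five registered later in this commit.

```python
# Minimal sketch (assumed API): interactive inference with a newly supported PaliGemma checkpoint.
from swift.llm import InferArguments, infer_main  # assumed import path

if __name__ == '__main__':
    # 'paligemma-3b-mix-224' is one of the model types added by this commit.
    infer_main(InferArguments(model_type='paligemma-3b-mix-224'))
```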

README_CN.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -532,7 +532,8 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
 | mPLUG-Owl | [mPLUG-Owl series models](https://github.com/X-PLUG/mPLUG-Owl) | English | 11B | chat model |
 | InternVL | [InternVL](https://github.com/OpenGVLab/InternVL) | Chinese<br>English | 25.5B<br>including quantized versions | chat model |
 | Llava-llama3 | [xtuner](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) | English | 8B | chat model |
-| Phi3 | Microsoft | English | 4B | chat model |
+| Phi3-Vision | Microsoft | English | 4B | chat model |
+| PaliGemma | Google | English | 3B | chat model |

 #### Diffusion Models

```

docs/source/LLM/支持的模型和数据集.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -200,6 +200,11 @@
 |gemma-7b|[AI-ModelScope/gemma-7b](https://modelscope.cn/models/AI-ModelScope/gemma-7b/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|
 |gemma-2b-instruct|[AI-ModelScope/gemma-2b-it](https://modelscope.cn/models/AI-ModelScope/gemma-2b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it)|
 |gemma-7b-instruct|[AI-ModelScope/gemma-7b-it](https://modelscope.cn/models/AI-ModelScope/gemma-7b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)|
+|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
+|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|multi-modal, vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
+|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
+|paligemma-3b-mix-224|[AI-ModelScope/paligemma-3b-mix-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)|
+|paligemma-3b-mix-448|[AI-ModelScope/paligemma-3b-mix-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448)|
 |minicpm-1b-sft-chat|[OpenBMB/MiniCPM-1B-sft-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-1B-sft-bf16/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;|transformers>=4.36.0|-|[openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16)|
 |minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|
 |minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|
```

docs/source/LLM/自定义与拓展.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -8,7 +8,7 @@
 We support three methods of **customizing datasets**.

 1. [Recommended] **Command line arguments**: **more convenient for supporting custom datasets**; supports four dataset formats (i.e. using `SmartPreprocessor`) and supports `dataset_id` and `dataset_path`.
-2. Adding the dataset to `dataset_info.json`: more flexible but more cumbersome than the first method; supports applying two preprocessors to the dataset and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (the default is `SmartPreprocessor`). You can modify swift's built-in `dataset_info.json` directly, or pass in an external json file via `--dataset_info_path xxx.json` (convenient for users who pip install rather than git clone swift to extend datasets).
+2. Adding the dataset to `dataset_info.json`: more flexible but more cumbersome than the first method; supports applying two preprocessors to the dataset and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (the default is `SmartPreprocessor`). You can modify swift's built-in `dataset_info.json` directly, or pass in an external json file via `--custom_dataset_info xxx.json` (convenient for users who pip install rather than git clone swift to extend datasets).
 3. **Registering the dataset**: more flexible but more cumbersome than methods 1 and 2; supports preprocessing the dataset with functions. Methods 1 and 2 are implemented with the help of method 3. You can modify the source code directly to extend, or pass in a file via `--custom_register_path xxx.py`; the script will parse the py file (convenient for pip install users).

 ### 📌 [Recommended] Command line arguments
```

docs/source_en/LLM/Customization.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@
 We support three methods for **customizing datasets**.

 1. \[Recommended\] using command line arguments: It is more convenient to support custom datasets, and it supports four dataset formats (using `SmartPreprocessor`) as well as the `dataset_id` and `dataset_path`.
-2. Adding datasets to `dataset_info.json` is more flexible but cumbersome compared to the first method, and supports using two preprocessors and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (default is to use `SmartPreprocessor`). You can directly modify the built-in `dataset_info.json` in Swift, or pass in an external json file using `--dataset_info_path xxx.json` (for users who prefer pip install over git clone to expand datasets).
+2. Adding datasets to `dataset_info.json` is more flexible but cumbersome compared to the first method, and supports using two preprocessors and specifying their parameters: `RenameColumnsPreprocessor`, `ConversationsPreprocessor` (default is to use `SmartPreprocessor`). You can directly modify the built-in `dataset_info.json` in Swift, or pass in an external json file using `--custom_dataset_info xxx.json` (for users who prefer pip install over git clone to expand datasets).
 3. Registering datasets: More flexible but cumbersome compared to the first and second methods, it supports using functions to preprocess datasets. Methods 1 and 2 are implemented by leveraging method 3. You can directly modify the source code for expansion, or pass in a custom registration path using `--custom_register_path xxx.py`, where the script will parse the py file (for pip install users).

 ### 📌 \[Recommended\] using Command Line Arguments
```
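
To make the renamed flag concrete, here is a hypothetical sketch of building an external file for `--custom_dataset_info xxx.json`. The entry name, dataset id, and key names are assumptions modeled loosely on swift's built-in `dataset_info.json`, not a verified schema; consult the built-in file for the authoritative keys.

```python
# Hypothetical sketch: an external dataset_info file for --custom_dataset_info.
# Every name and key below is an assumption for illustration, not a documented schema.
import json

custom_dataset_info = {
    'my-custom-alpaca': {                                    # hypothetical dataset name
        'dataset_id': 'AI-ModelScope/alpaca-gpt4-data-en',   # hypothetical ModelScope dataset id
        'columns': {'instruction': 'query', 'output': 'response'},  # assumed RenameColumnsPreprocessor-style mapping
    },
}

with open('xxx.json', 'w', encoding='utf-8') as f:
    json.dump(custom_dataset_info, f, ensure_ascii=False, indent=2)
# Then pass it on the command line as: --custom_dataset_info xxx.json
```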

docs/source_en/LLM/Supported-models-datasets.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -200,6 +200,11 @@ The table below introduces all models supported by SWIFT:
 |gemma-7b|[AI-ModelScope/gemma-7b](https://modelscope.cn/models/AI-ModelScope/gemma-7b/summary)|q_proj, k_proj, v_proj|default-generation|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|
 |gemma-2b-instruct|[AI-ModelScope/gemma-2b-it](https://modelscope.cn/models/AI-ModelScope/gemma-2b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it)|
 |gemma-7b-instruct|[AI-ModelScope/gemma-7b-it](https://modelscope.cn/models/AI-ModelScope/gemma-7b-it/summary)|q_proj, k_proj, v_proj|gemma|&#x2714;|&#x2714;|transformers>=4.38|-|[google/gemma-7b-it](https://huggingface.co/google/gemma-7b-it)|
+|paligemma-3b-pt-224|[AI-ModelScope/paligemma-3b-pt-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-224](https://huggingface.co/google/paligemma-3b-pt-224)|
+|paligemma-3b-pt-448|[AI-ModelScope/paligemma-3b-pt-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|multi-modal, vision|[google/paligemma-3b-pt-448](https://huggingface.co/google/paligemma-3b-pt-448)|
+|paligemma-3b-pt-896|[AI-ModelScope/paligemma-3b-pt-896](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-pt-896/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-pt-896](https://huggingface.co/google/paligemma-3b-pt-896)|
+|paligemma-3b-mix-224|[AI-ModelScope/paligemma-3b-mix-224](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-224/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-224](https://huggingface.co/google/paligemma-3b-mix-224)|
+|paligemma-3b-mix-448|[AI-ModelScope/paligemma-3b-mix-448](https://modelscope.cn/models/AI-ModelScope/paligemma-3b-mix-448/summary)|q_proj, k_proj, v_proj|paligemma|&#x2714;|&#x2718;|transformers>=4.41|-|[google/paligemma-3b-mix-448](https://huggingface.co/google/paligemma-3b-mix-448)|
 |minicpm-1b-sft-chat|[OpenBMB/MiniCPM-1B-sft-bf16](https://modelscope.cn/models/OpenBMB/MiniCPM-1B-sft-bf16/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;|transformers>=4.36.0|-|[openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16)|
 |minicpm-2b-sft-chat|[OpenBMB/MiniCPM-2B-sft-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-sft-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-sft-fp32](https://huggingface.co/openbmb/MiniCPM-2B-sft-fp32)|
 |minicpm-2b-chat|[OpenBMB/MiniCPM-2B-dpo-fp32](https://modelscope.cn/models/OpenBMB/MiniCPM-2B-dpo-fp32/summary)|q_proj, k_proj, v_proj|minicpm|&#x2714;|&#x2714;||-|[openbmb/MiniCPM-2B-dpo-fp32](https://huggingface.co/openbmb/MiniCPM-2B-dpo-fp32)|
```

swift/llm/utils/model.py

Lines changed: 75 additions & 4 deletions

```diff
@@ -251,6 +251,12 @@ class ModelType:
     gemma_7b = 'gemma-7b'
     gemma_2b_instruct = 'gemma-2b-instruct'
     gemma_7b_instruct = 'gemma-7b-instruct'
+    # paligemma
+    paligemma_3b_pt_224 = 'paligemma-3b-pt-224'
+    paligemma_3b_pt_448 = 'paligemma-3b-pt-448'
+    paligemma_3b_pt_896 = 'paligemma-3b-pt-896'
+    paligemma_3b_mix_224 = 'paligemma-3b-mix-224'
+    paligemma_3b_mix_448 = 'paligemma-3b-mix-448'
     # minicpm
     minicpm_1b_sft_chat = 'minicpm-1b-sft-chat'
     minicpm_2b_sft_chat = 'minicpm-2b-sft-chat'
@@ -532,15 +538,15 @@ def _new_forward(self, x):
     LoRATM.cogvlm,
     TemplateType.cogvlm,
     support_gradient_checkpointing=False,
-    pad_token='<|reserved_special_token_0|>',
+    placeholder_tokens=['<|reserved_special_token_0|>'],
     hf_model_id='THUDM/cogvlm2-llama3-chat-19B')
 @register_model(
     ModelType.cogvlm2_19b_chat,
     'ZhipuAI/cogvlm2-llama3-chinese-chat-19B',
     LoRATM.cogvlm,
     TemplateType.cogvlm,
     support_gradient_checkpointing=False,
-    pad_token='<|reserved_special_token_0|>',
+    placeholder_tokens=['<|reserved_special_token_0|>'],
     hf_model_id='THUDM/cogvlm2-llama3-chinese-chat-19B')
 @register_model(
     ModelType.atom_7b,
@@ -844,6 +850,10 @@ def get_model_tokenizer_from_repo(model_dir: str,
     pad_token = kwargs.get('pad_token')
     if pad_token is not None:
         tokenizer.pad_token = pad_token
+    placeholder_tokens = kwargs.get('placeholder_tokens')
+    if placeholder_tokens is not None:
+        tokenizer.placeholder_tokens = placeholder_tokens
+        tokenizer.placeholder_tokens_id = [tokenizer.convert_tokens_to_ids(token) for token in placeholder_tokens]
     model = None
     if load_model:
         if kwargs.get('use_unsloth', False):
@@ -1083,6 +1093,65 @@ def get_model_tokenizer_baichuan_13b(model_dir: str,
     return model, tokenizer


+@register_model(
+    ModelType.paligemma_3b_pt_224,
+    'AI-ModelScope/paligemma-3b-pt-224',
+    LoRATM.llama2,
+    TemplateType.paligemma,
+    support_flash_attn=True,
+    requires=['transformers>=4.41'],
+    placeholder_tokens=['<image>'],
+    hf_model_id='google/paligemma-3b-pt-224')
+@register_model(
+    ModelType.paligemma_3b_pt_448,
+    'AI-ModelScope/paligemma-3b-pt-448',
+    LoRATM.llama2,
+    TemplateType.paligemma,
+    support_flash_attn=True,
+    requires=['transformers>=4.41'],
+    placeholder_tokens=['<image>'],
+    tags=['multi-modal', 'vision'],
+    hf_model_id='google/paligemma-3b-pt-448')
+@register_model(
+    ModelType.paligemma_3b_pt_896,
+    'AI-ModelScope/paligemma-3b-pt-896',
+    LoRATM.llama2,
+    TemplateType.paligemma,
+    support_flash_attn=True,
+    requires=['transformers>=4.41'],
+    placeholder_tokens=['<image>'],
+    hf_model_id='google/paligemma-3b-pt-896')
+@register_model(
+    ModelType.paligemma_3b_mix_224,
+    'AI-ModelScope/paligemma-3b-mix-224',
+    LoRATM.llama2,
+    TemplateType.paligemma,
+    support_flash_attn=True,
+    requires=['transformers>=4.41'],
+    placeholder_tokens=['<image>'],
+    hf_model_id='google/paligemma-3b-mix-224')
+@register_model(
+    ModelType.paligemma_3b_mix_448,
+    'AI-ModelScope/paligemma-3b-mix-448',
+    LoRATM.llama2,
+    TemplateType.paligemma,
+    support_flash_attn=True,
+    requires=['transformers>=4.41'],
+    placeholder_tokens=['<image>'],
+    hf_model_id='google/paligemma-3b-mix-448')
+def get_model_tokenizer_paligemma_vision(model_dir: str,
+                                         torch_dtype: Dtype,
+                                         model_kwargs: Dict[str, Any],
+                                         load_model: bool = True,
+                                         **kwargs):
+    from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
+    model, tokenizer = get_model_tokenizer_from_repo(
+        model_dir, torch_dtype, model_kwargs, load_model, automodel_class=PaliGemmaForConditionalGeneration, **kwargs)
+    tokenizer.processor = processor
+    return model, tokenizer
+
+
 @register_model(
     ModelType.phi3_vision_128k_instruct,
     'LLM-Research/Phi-3-vision-128k-instruct',
@@ -2678,7 +2747,6 @@ def get_model_tokenizer_deepseek2(model_dir: str,
     if model is not None:
         model.generation_config.pad_token_id = model.generation_config.eos_token_id
         # fix dtype bug
-        model.generation_config.pad_token_id = model.generation_config.eos_token_id
         mlp_cls = model.model.layers[1].mlp.__class__
         for module in model.modules():
             if isinstance(module, mlp_cls):
@@ -4051,7 +4119,7 @@ def _new_forward(*args, **kwargs) -> Tensor:
     TemplateType.minicpm_v_v2_5,
     support_flash_attn=True,
     requires=['timm'],
-    pad_token='<unk>',
+    placeholder_tokens=['<unk>'],
     function_kwargs={'patching_embedding': True},
     hf_model_id='openbmb/MiniCPM-Llama3-V-2_5')
 def get_model_tokenizer_minicpm_v(model_dir: str,
@@ -4396,6 +4464,9 @@ def get_model_tokenizer(model_type: str,
     pad_token = model_info.get('pad_token')
     if pad_token is not None:
         kwargs['pad_token'] = pad_token
+    placeholder_tokens = model_info.get('placeholder_tokens')
+    if placeholder_tokens is not None:
+        kwargs['placeholder_tokens'] = placeholder_tokens
     if 'is_training' not in kwargs:
         kwargs['is_training'] = False
     model, tokenizer = get_function(model_dir, torch_dtype, model_kwargs, load_model, **kwargs)
```
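
The registrations above route the five new model types to `get_model_tokenizer_paligemma_vision`, which attaches the `AutoProcessor` to the tokenizer and forwards `placeholder_tokens` to `get_model_tokenizer_from_repo`. Below is a hedged sketch of loading one of them through the generic entry point, assuming `ModelType` and `get_model_tokenizer` are re-exported from `swift.llm`.

```python
# Sketch (assumed import path): load a registered PaliGemma checkpoint via the generic loader.
import torch
from swift.llm import ModelType, get_model_tokenizer  # assumed re-export of swift/llm/utils/model.py

model, tokenizer = get_model_tokenizer(
    ModelType.paligemma_3b_mix_224,
    torch.bfloat16,
    {'device_map': 'auto'})

# Attributes wired up by this commit:
print(tokenizer.placeholder_tokens)      # ['<image>']
print(tokenizer.placeholder_tokens_id)   # ids resolved via convert_tokens_to_ids
print(tokenizer.processor.image_seq_length)
```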

swift/llm/utils/template.py

Lines changed: 43 additions & 1 deletion

```diff
@@ -14,7 +14,7 @@

 from swift.llm.agent.utils import calculate_loss_scale
 from swift.torchacc_utils import pad_and_split_batch
-from swift.utils import get_dist_setting, use_torchacc
+from swift.utils import get_dist_setting, upper_bound, use_torchacc

 DEFAULT_SYSTEM = 'You are a helpful assistant.'
 History = List[Union[Tuple[str, str], List[str]]]
@@ -70,6 +70,7 @@ class TemplateType:
     minicpm_v = 'minicpm-v'
     minicpm_v_v2_5 = 'minicpm-v-v2_5'
     gemma = 'gemma'
+    paligemma = 'paligemma'
     mplug_owl2 = 'mplug-owl2'
     wizardlm2_awq = 'wizardlm2-awq'
     wizardlm2 = 'wizardlm2'
@@ -1044,6 +1045,8 @@ def __init__(self):

     def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         inputs, _ = super().encode(example)
+        if len(inputs) == 0:
+            return inputs, {}
         image_path = example['images']
         raw_image = _read_from_path(image_path[0])
         pixel_values = self.tokenizer.processor.image_processor(raw_image, return_tensors='pt')['pixel_values']
@@ -1064,6 +1067,45 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
     lazy_tokenize=True)


+class PaliGemmaTemplate(Template):
+
+    def __init__(self):
+        Template.__init__(self, ['<bos>'], ['{{QUERY}}\n'], None, ['<eos>'])
+
+    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        inputs, _ = super().encode(example)
+        image_token = self.tokenizer.encode('<image>', add_special_tokens=False)
+        assert len(image_token) == 1
+        image_token = image_token[0]
+        if len(inputs) == 0:
+            return inputs, {}
+        image_path = example['images']
+        processor = self.tokenizer.processor
+        inputs['input_ids'] = [image_token] * processor.image_seq_length + inputs['input_ids']
+        if inputs['labels'] is not None:
+            n = upper_bound(0, len(inputs['labels']), lambda idx: inputs['labels'][idx] == -100)
+            n2 = len(inputs['labels']) - n
+            inputs['labels'] = [-100] * processor.image_seq_length + inputs['labels']
+            inputs['token_type_ids'] = [0] * (processor.image_seq_length + n) + [1] * n2
+        else:
+            inputs['token_type_ids'] = [0] * len(inputs['input_ids'])
+        raw_image = _read_from_path(image_path[0])
+        model_inputs = processor(text=example['query'], images=raw_image, return_tensors='pt')
+        inputs['pixel_values'] = model_inputs['pixel_values']
+        return inputs, {}
+
+    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super().data_collator(batch, padding_to)
+        res['pixel_values'] = torch.concat([b['pixel_values'] for b in batch])
+        token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
+        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)
+        res['token_type_ids'] = token_type_ids
+        return res
+
+
+register_template(TemplateType.paligemma, PaliGemmaTemplate(), infer_media_type='dialogue', lazy_tokenize=True)
+
+
 class Phi3VisionTemplate(Template):

     def __init__(self):
```
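
As the `encode` implementation above shows, the template prepends `processor.image_seq_length` copies of the `<image>` token to `input_ids`, masks that prefix with `-100` in `labels`, and uses `token_type_ids` to separate the image-plus-prompt prefix (0) from the answer tokens (1). Below is a hedged sketch of exercising it, assuming `get_template` and `get_model_tokenizer` are re-exported from `swift.llm` and using made-up prompt and image paths.

```python
# Sketch (assumed imports, made-up data): encoding one image-text pair with the new template.
from swift.llm import ModelType, TemplateType, get_model_tokenizer, get_template  # assumed re-exports

_, tokenizer = get_model_tokenizer(ModelType.paligemma_3b_mix_224, load_model=False)
template = get_template(TemplateType.paligemma, tokenizer)

example = {
    'query': 'caption en',           # made-up prompt
    'response': 'a cat on a sofa',   # made-up label
    'images': ['cat.jpg'],           # hypothetical local image path
}
inputs, _ = template.encode(example)
# inputs['input_ids'] begins with image_seq_length '<image>' token ids;
# inputs['token_type_ids'] is 0 over the image+prompt prefix and 1 over the response;
# inputs['pixel_values'] holds the processed image tensor from the processor.
```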
