|
8 | 8 | from transformers import (
|
9 | 9 | AutoProcessor,
|
10 | 10 | Qwen2VLForConditionalGeneration,
|
| 11 | + Qwen2_5_VLForConditionalGeneration, |
11 | 12 | Trainer,
|
12 | 13 | TrainingArguments,
|
13 | 14 | EarlyStoppingCallback
|
@@ -92,13 +93,24 @@ def main():
|
92 | 93 |
|
93 | 94 | # Load model
|
94 | 95 | logger.info(f"Loading model: {config.model.name}")
|
95 |
| - model = Qwen2VLForConditionalGeneration.from_pretrained( |
96 |
| - config.model.name, |
97 |
| - torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto", |
98 |
| - device_map=config.model.device_map, |
99 |
| - trust_remote_code=config.model.trust_remote_code, |
100 |
| - attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None, |
101 |
| - ) |
| 96 | + if "Qwen2.5-VL" in config.model.name: |
| 97 | + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( |
| 98 | + config.model.name, |
| 99 | + torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto", |
| 100 | + device_map=config.model.device_map, |
| 101 | + trust_remote_code=config.model.trust_remote_code, |
| 102 | + attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None, |
| 103 | + ) |
| 104 | + elif "Qwen2-VL" in config.model.name: |
| 105 | + model = Qwen2VLForConditionalGeneration.from_pretrained( |
| 106 | + config.model.name, |
| 107 | + torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto", |
| 108 | + device_map=config.model.device_map, |
| 109 | + trust_remote_code=config.model.trust_remote_code, |
| 110 | + attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None, |
| 111 | + ) |
| 112 | + else: |
| 113 | + raise NotImplementedError() |
102 | 114 |
|
103 | 115 | # Enable gradient checkpointing if configured
|
104 | 116 | if config.training.gradient_checkpointing:
|
|
0 commit comments