Skip to content

Commit 850b598

Browse files
committed
Sdpa: switch the example config's attention implementation from flash_attention_2 to sdpa, and add a Qwen2.5-VL loading branch in train.py
1 parent b96454b commit 850b598

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

olmocr/train/configs/example_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ model:
1010
trust_remote_code: true
1111
torch_dtype: auto
1212
use_flash_attention: true
13-
attn_implementation: flash_attention_2
13+
attn_implementation: sdpa
1414

1515
# LoRA settings (disabled by default)
1616
use_lora: false

olmocr/train/train.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from transformers import (
99
AutoProcessor,
1010
Qwen2VLForConditionalGeneration,
11+
Qwen2_5_VLForConditionalGeneration,
1112
Trainer,
1213
TrainingArguments,
1314
EarlyStoppingCallback
@@ -92,13 +93,24 @@ def main():
9293

9394
# Load model
9495
logger.info(f"Loading model: {config.model.name}")
95-
model = Qwen2VLForConditionalGeneration.from_pretrained(
96-
config.model.name,
97-
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
98-
device_map=config.model.device_map,
99-
trust_remote_code=config.model.trust_remote_code,
100-
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
101-
)
96+
if "Qwen2.5-VL" in config.model.name:
97+
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
98+
config.model.name,
99+
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
100+
device_map=config.model.device_map,
101+
trust_remote_code=config.model.trust_remote_code,
102+
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
103+
)
104+
elif "Qwen2-VL" in config.model.name:
105+
model = Qwen2VLForConditionalGeneration.from_pretrained(
106+
config.model.name,
107+
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
108+
device_map=config.model.device_map,
109+
trust_remote_code=config.model.trust_remote_code,
110+
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
111+
)
112+
else:
113+
raise NotImplementedError()
102114

103115
# Enable gradient checkpointing if configured
104116
if config.training.gradient_checkpointing:

0 commit comments

Comments (0)