File tree Expand file tree Collapse file tree 4 files changed +12
-6
lines changed Expand file tree Collapse file tree 4 files changed +12
-6
lines changed Original file line number Diff line number Diff line change 10
10
trust_remote_code : true
11
11
torch_dtype : auto
12
12
use_flash_attention : true
13
- attn_implementation : sdpa
13
+ attn_implementation : flash_attention_2
14
14
15
15
# LoRA settings (disabled by default)
16
16
use_lora : false
@@ -65,9 +65,11 @@ training:
65
65
per_device_train_batch_size : 1
66
66
per_device_eval_batch_size : 1
67
67
gradient_accumulation_steps : 8
68
+
69
+ gradient_checkpointing : False
68
70
69
71
# Learning rate
70
- learning_rate : 2e-5
72
+ learning_rate : 1e-6
71
73
lr_scheduler_type : cosine
72
74
warmup_ratio : 0.1
73
75
Original file line number Diff line number Diff line change @@ -162,9 +162,13 @@ def main():
162
162
total_eval_samples = sum (len (dataset ) for dataset in eval_datasets .values ())
163
163
logger .info (f"Total evaluation samples across { len (eval_datasets )} datasets: { total_eval_samples } " )
164
164
165
+ # Construct full output directory by appending run_name to base output_dir
166
+ full_output_dir = os .path .join (config .training .output_dir , config .run_name )
167
+ logger .info (f"Setting output directory to: { full_output_dir } " )
168
+
165
169
# Set up training arguments
166
170
training_args = TrainingArguments (
167
- output_dir = config . training . output_dir ,
171
+ output_dir = full_output_dir ,
168
172
num_train_epochs = config .training .num_train_epochs ,
169
173
per_device_train_batch_size = config .training .per_device_train_batch_size ,
170
174
per_device_eval_batch_size = config .training .per_device_eval_batch_size ,
Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ dependencies = [
37
37
" boto3" ,
38
38
" httpx" ,
39
39
" torch>=2.7.0" ,
40
- " transformers>=4.51.1 " ,
40
+ " transformers==4.52.4 " ,
41
41
" img2pdf" ,
42
42
" beaker-py" ,
43
43
]
Original file line number Diff line number Diff line change @@ -52,7 +52,7 @@ gantry run \
52
52
--priority normal \
53
53
--gpus 1 \
54
54
--preemptible \
55
- --cluster " ai2/jupiter -cirrascale-2 " \
55
+ --cluster " ai2/titan -cirrascale" \
56
56
--budget ai2/oe-data \
57
57
--env LOG_FILTER_TYPE=local_rank0_only \
58
58
--env OMP_NUM_THREADS=8 \
@@ -64,4 +64,4 @@ gantry run \
64
64
--weka oe-training-default:/weka/oe-training-default \
65
65
--shared-memory 10GiB \
66
66
--yes \
67
- -- /bin/bash -c " source scripts/beaker/jupiter-ib.sh && python -m olmocr.train.train --config olmocr/train/configs/example_config.yaml"
67
+ -- /bin/bash -c " pip install flash-attn==2.8.0.post2 --no-build-isolation && python -m olmocr.train.train --config olmocr/train/configs/example_config.yaml"
You can’t perform that action at this time.
0 commit comments