Skip to content

Commit 2b63855

Browse files
committed
Compare has better downloader
1 parent: 0b40bd3 · commit: 2b63855

File tree

1 file changed

+41
-30
lines changed

1 file changed

+41
-30
lines changed

olmocr/train/compare_vllm_checkpoint.py

Lines changed: 41 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -70,46 +70,57 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
7070
random.seed(seed)
7171
np.random.seed(seed)
7272

73-
# Download dataset to a temporary directory
73+
# Import huggingface_hub utilities to list files
74+
from huggingface_hub import list_repo_files, hf_hub_download
75+
76+
# List all PDF files in the repository
77+
print("Listing PDF files in dataset...")
78+
all_files = list_repo_files(
79+
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
80+
repo_type="dataset"
81+
)
82+
83+
# Filter for PDF files in the pdfs directory
84+
pdf_files = [f for f in all_files if f.startswith("pdfs/") and f.endswith(".pdf")]
85+
86+
if not pdf_files:
87+
raise ValueError("No PDF files found in the dataset")
88+
89+
print(f"Found {len(pdf_files)} PDF files in dataset")
90+
91+
# Randomly sample num_samples PDFs
92+
if len(pdf_files) > num_samples:
93+
sampled_pdf_files = random.sample(pdf_files, num_samples)
94+
else:
95+
sampled_pdf_files = pdf_files
96+
print(f"Warning: Only {len(pdf_files)} PDFs available, less than requested {num_samples}")
97+
98+
print(f"Sampled {len(sampled_pdf_files)} PDFs to download")
99+
100+
# Download only the sampled PDFs and process them
101+
queries = []
74102
with tempfile.TemporaryDirectory() as temp_dir:
75-
print("Downloading dataset...")
76-
dataset_path = snapshot_download(
77-
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
78-
repo_type="dataset",
79-
local_dir=temp_dir,
80-
allow_patterns="pdfs/*.pdf" # Only download PDF files
81-
)
82-
83-
# Find all PDF files in the pdfs directory
84-
pdf_pattern = os.path.join(dataset_path, "pdfs", "*.pdf")
85-
pdf_files = glob.glob(pdf_pattern)
86-
87-
if not pdf_files:
88-
raise ValueError(f"No PDF files found in {pdf_pattern}")
89-
90-
print(f"Found {len(pdf_files)} PDF files")
91-
92-
# Randomly sample num_samples PDFs
93-
if len(pdf_files) > num_samples:
94-
sampled_pdfs = random.sample(pdf_files, num_samples)
95-
else:
96-
sampled_pdfs = pdf_files
97-
print(f"Warning: Only {len(pdf_files)} PDFs available, less than requested {num_samples}")
98-
99-
# Process each PDF and build queries
100-
queries = []
101-
for pdf_path in sampled_pdfs:
103+
for pdf_file in sampled_pdf_files:
102104
try:
105+
# Download individual PDF file
106+
print(f"Downloading {pdf_file}...")
107+
local_pdf_path = hf_hub_download(
108+
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
109+
filename=pdf_file,
110+
repo_type="dataset",
111+
local_dir=temp_dir
112+
)
113+
103114
# Build page query for page 1 of each PDF
104115
query = await build_page_query(
105-
local_pdf_path=pdf_path,
116+
local_pdf_path=local_pdf_path,
106117
page=1,
107118
target_longest_image_dim=1280,
108119
image_rotation=0
109120
)
110121
queries.append(query)
111122
except Exception as e:
112-
print(f"Error processing {os.path.basename(pdf_path)}: {e}")
123+
print(f"Error processing {os.path.basename(pdf_file)}: {e}")
113124
continue
114125

115126
print(f"Successfully processed {len(queries)} PDFs")

0 commit comments

Comments
 (0)