import multiprocessing
import threading
from multiprocessing import Queue
# from arrayqueues.shared_arrays import ArrayQueue
# from faster_fifo import Queue
import time
import albumentations as A
import queue
import cv2
from functools import partial
from typing import Any, Dict, List, Tuple

import numpy as np

from datasets import load_dataset, concatenate_datasets, Dataset
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
import io
import urllib.request

import PIL.Image

USER_AGENT = get_datasets_user_agent()

# Module-level queues; loader processes inherit them via fork and push
# processed samples (or errors) back to the consumer.
data_queue = Queue(16 * 2000)
error_queue = Queue(16 * 2000)

def fetch_single_image(image_url, timeout=None, retries=0):
    # Fetch an image over HTTP, returning None if all attempts fail.
    image = None
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image

def map_sample(
    url, caption,
    image_shape=(256, 256),
    upscale_interpolation=cv2.INTER_LANCZOS4,
    downscale_interpolation=cv2.INTER_AREA,
):
    try:
        image = fetch_single_image(url, timeout=15, retries=3)
        if image is None:
            return

        image = np.array(image)
        original_height, original_width = image.shape[:2]
        # Skip images that are too small
        if min(original_height, original_width) < min(image_shape):
            return
        # Skip images with an extreme aspect ratio
        if max(original_height, original_width) / min(original_height, original_width) > 2:
            return
        # Skip images whose pixel variance is too low (e.g. blank images)
        if np.std(image) < 1e-4:
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Resize the longest side to the target size, using the area filter when
        # downscaling and Lanczos when upscaling.
        downscale = max(original_width, original_height) > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation
        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        data_queue.put({
            "url": url,
            "caption": caption,
            "image": image,
        })
    except Exception as e:
        error_queue.put({
            "url": url,
            "caption": caption,
            "error": str(e),
        })

def map_batch(batch, num_threads=256, timeout=None, retries=0):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(map_sample, batch["url"], batch["caption"])

def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
    map_batch_fn = partial(map_batch, num_threads=num_threads)
    shard_len = len(dataset) // num_workers
    print(f"Local shard length: {shard_len}")
    with multiprocessing.Pool(num_workers) as pool:
        iteration = 0
        while True:
            # Repeat forever, reshuffling the dataset on every pass
            dataset = dataset.shuffle(seed=iteration)
            shards = [dataset[i * shard_len:(i + 1) * shard_len] for i in range(num_workers)]
            pool.map(map_batch_fn, shards)
            iteration += 1

class ImageBatchIterator:
    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(parallel_image_loader, num_threads=num_threads)
        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Pull batch_size processed samples off the shared queue in parallel
        def fetcher(_):
            return data_queue.get()
        with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
            batch = list(executor.map(fetcher, range(self.batch_size)))
        return batch

    def __del__(self):
        self.thread.join()

    def __len__(self):
        return len(self.dataset) // self.batch_size

def default_collate(batch):
    urls = [sample["url"] for sample in batch]
    captions = [sample["caption"] for sample in batch]
    images = np.stack([sample["image"] for sample in batch], axis=0)
    return {
        "url": urls,
        "caption": captions,
        "image": images,
    }

def dataMapper(map: Dict[str, Any]):
    def _map(sample) -> Dict[str, Any]:
        return {
            "url": sample[map["url"]],
            "caption": sample[map["caption"]],
        }
    return _map

class OnlineStreamingDataLoader():
    def __init__(
        self,
        dataset,
        batch_size=64,
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
    ):
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.map(pre_map_maker(pre_map_def))
        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
        self.batch_size = batch_size
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn

        # Launch a thread to load batches in the background
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                self.batch_queue.put(batch)

        self.loader_thread = threading.Thread(target=batch_loader)
        self.loader_thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        return self.collate_fn(self.batch_queue.get())
        # return self.collate_fn(next(self.iterator))

    def __len__(self):
        return len(self.dataset) // self.batch_size

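# Example usage: a minimal sketch, not part of the original module. The dataset
# path and column mapping below are illustrative assumptions; any Hugging Face
# dataset exposing image URLs and captions should work with a matching pre_map_def.
if __name__ == "__main__":
    dataloader = OnlineStreamingDataLoader(
        "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",  # hypothetical dataset path
        batch_size=16,
        num_workers=4,
        num_threads=64,
        default_split="train",
        pre_map_def={"url": "URL", "caption": "TEXT"},
    )
    batch = next(iter(dataloader))
    print(batch["image"].shape, len(batch["caption"]))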