
Commit 6244b6a

feat: new online streaming dataloader and VAE dtype conf
1 parent 3cacef5 commit 6244b6a

File tree

6 files changed: +458 -225 lines changed

datasets/dataset preparations.ipynb

Lines changed: 103 additions & 166 deletions
Large diffs are not rendered by default.

evaluate.ipynb

Lines changed: 146 additions & 55 deletions
Large diffs are not rendered by default.

flaxdiff/data/__init__.py

Whitespace-only changes.

flaxdiff/data/online_loader.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
import multiprocessing
import threading
from multiprocessing import Queue
# from arrayqueues.shared_arrays import ArrayQueue
# from faster_fifo import Queue
import time
import albumentations as A
import queue
import cv2
from functools import partial
from typing import Any, Dict, List, Tuple

import numpy as np

from datasets import load_dataset, concatenate_datasets, Dataset
from datasets.utils.file_utils import get_datasets_user_agent
from concurrent.futures import ThreadPoolExecutor
import io
import urllib

import PIL.Image

USER_AGENT = get_datasets_user_agent()

# Shared queues: worker processes push processed samples into data_queue and
# failed downloads into error_queue.
data_queue = Queue(16 * 2000)
error_queue = Queue(16 * 2000)


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image


def map_sample(
    url, caption,
    image_shape=(256, 256),
    upscale_interpolation=cv2.INTER_LANCZOS4,
    downscale_interpolation=cv2.INTER_AREA,
):
    try:
        image = fetch_single_image(url, timeout=15, retries=3)
        if image is None:
            return

        image = np.array(image)
        original_height, original_width = image.shape[:2]
        # check if the image is too small
        if min(original_height, original_width) < min(image_shape):
            return
        # check if wrong aspect ratio
        if max(original_height, original_width) / min(original_height, original_width) > 2:
            return
        # check if the variance is too low
        if np.std(image) < 1e-4:
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # downscale with INTER_AREA, upscale with INTER_LANCZOS4, then pad to a square canvas
        downscale = max(original_width, original_height) > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation
        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        data_queue.put({
            "url": url,
            "caption": caption,
            "image": image
        })
    except Exception as e:
        error_queue.put({
            "url": url,
            "caption": caption,
            "error": str(e)
        })


def map_batch(batch, num_threads=256, timeout=None, retries=0):
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(map_sample, batch["url"], batch["caption"])


def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
    map_batch_fn = partial(map_batch, num_threads=num_threads)
    shard_len = len(dataset) // num_workers
    print(f"Local shard length: {shard_len}")
    with multiprocessing.Pool(num_workers) as pool:
        iteration = 0
        while True:
            # Repeat forever, reshuffling the dataset each pass
            dataset = dataset.shuffle(seed=iteration)
            shards = [dataset[i * shard_len:(i + 1) * shard_len] for i in range(num_workers)]
            pool.map(map_batch_fn, shards)
            iteration += 1


class ImageBatchIterator:
    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        # Background thread keeps data_queue filled via a multiprocessing pool
        loader = partial(parallel_image_loader, num_threads=num_threads)
        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        def fetcher(_):
            return data_queue.get()
        with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
            batch = list(executor.map(fetcher, range(self.batch_size)))
        return batch

    def __del__(self):
        self.thread.join()

    def __len__(self):
        return len(self.dataset) // self.batch_size


def default_collate(batch):
    urls = [sample["url"] for sample in batch]
    captions = [sample["caption"] for sample in batch]
    images = np.stack([sample["image"] for sample in batch], axis=0)
    return {
        "url": urls,
        "caption": captions,
        "image": images,
    }


def dataMapper(map: Dict[str, Any]):
    def _map(sample) -> Dict[str, Any]:
        return {
            "url": sample[map["url"]],
            "caption": sample[map["caption"]],
        }
    return _map


class OnlineStreamingDataLoader():
    def __init__(
        self,
        dataset,
        batch_size=64,
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
    ):
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.map(pre_map_maker(pre_map_def))
        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
        self.batch_size = batch_size  # used by __len__
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn

        # Launch a thread to load batches in the background
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                self.batch_queue.put(batch)

        self.loader_thread = threading.Thread(target=batch_loader)
        self.loader_thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        return self.collate_fn(self.batch_queue.get())
        # return self.collate_fn(next(self.iterator))

    def __len__(self):
        return len(self.dataset) // self.batch_size
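
For context, a minimal usage sketch of the new loader (not part of the commit itself). The dataset id below is a placeholder; any HuggingFace dataset whose columns match pre_map_def (here "URL" for the image link and "TEXT" for the caption) should work the same way:

from flaxdiff.data.online_loader import OnlineStreamingDataLoader

dataloader = OnlineStreamingDataLoader(
    "user/my-url-caption-dataset",  # placeholder dataset id with "URL" and "TEXT" columns
    batch_size=64,
    num_workers=16,
    num_threads=512,
    default_split="train",
    pre_map_def={"url": "URL", "caption": "TEXT"},
)

batch = next(iter(dataloader))
# batch["image"]: (64, 256, 256, 3) uint8 array; batch["url"] and batch["caption"]: lists of length 64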

flaxdiff/models/autoencoder/diffusers.py

Lines changed: 3 additions & 3 deletions
@@ -11,15 +11,15 @@
     """
 
 class StableDiffusionVAE(AutoEncoder):
-    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
+    def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
 
         from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
         from diffusers import FlaxStableDiffusionPipeline
 
         pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
             modelname,
-            revision="bf16",
-            dtype=jnp.bfloat16,
+            revision=revision,
+            dtype=dtype,
         )
 
         vae = pipeline.vae
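
With this change the VAE checkpoint revision and compute dtype are configurable rather than hard-coded to the bf16 weights. A brief sketch of the two call patterns (that the "main" revision ships full-precision Flax weights is an assumption about the checkpoint repository, not something this commit guarantees):

import jax.numpy as jnp
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE

vae_bf16 = StableDiffusionVAE()  # same behaviour as before: revision="bf16", dtype=jnp.bfloat16
vae_fp32 = StableDiffusionVAE(revision="main", dtype=jnp.float32)  # assumes a full-precision Flax revision exists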

setup.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.12',
+    version='0.1.13',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
