|
7 | 7 | """
|
8 | 8 |
|
9 | 9 | import h5py
|
10 |
| -from swmr_tools.KeyFollower import Follower, FrameGrabber |
| 10 | +from swmr_tools.KeyFollower import Follower |
11 | 11 | import numpy as np
|
12 |
| -complete_keys = np.arange(25).reshape(5,5,1,1) + 1 |
13 |
| -complete_dataset = np.random.randint(low = 1, high = 5000, size = (5,5,10,20)) |
14 |
| - |
15 |
| -with h5py.File("test.h5", "w", libver = "latest") as f: |
16 |
| - f.create_group('keys') |
17 |
| - f.create_group('data') |
18 |
| - f['keys'].create_dataset("key_1", data = complete_keys) |
19 |
| - f['data'].create_dataset("data_1", data = complete_dataset) |
20 |
| - |
21 |
| - |
22 |
| -import time |
23 |
| - |
24 |
| -def long_function(key, filepath = "test.h5", dataset = "data/data_1"): |
25 |
| - time.sleep(1) |
26 |
| - print(f"Starting key {key}") |
27 |
| - with h5py.File(filepath, "r", swmr = True) as f: |
28 |
| - fg = FrameGrabber(dataset, f) |
29 |
| - frame = fg.Grabber(key) |
30 |
| - print(f"getting frame sum {frame.sum()}") |
31 |
| - return frame.sum() |
32 |
| - |
33 |
| -def key_generator(queue, filepath = "test.h5"): |
34 |
| - with h5py.File(filepath, "r", swmr = True) as f: |
35 |
| - kf = Follower(f, ['keys'], timeout = 0.1) |
36 |
| - for key in kf: |
37 |
| - queue.put(key) |
38 |
| - |
39 |
| - queue.put("End") |
40 |
| - |
41 |
| -def frame_consumer_serial(queue, filepath = "test.h5", dataset = "data/data_1"): |
42 |
| - return_list = [] |
43 |
| - key = queue.get() |
44 |
| - while key != 'End': |
45 |
| - #print(key) |
46 |
| - return_list.append(long_function(key)) |
47 |
| - key = queue.get() |
48 |
| - #print("Done") |
49 |
| - return return_list |
50 |
| - |
51 |
| - |
52 |
| -from dask.distributed import Client |
53 |
| - |
54 |
| -def frame_consumer_parallel(queue, filepath = "test.h5", dataset = "data/data_1"): |
55 |
| - return_list = [] |
56 |
| - client = Client() |
57 |
| - key = queue.get() |
58 |
| - while key != 'End': |
59 |
| - return_list.append(client.submit(long_function, key)) |
60 |
| - key = queue.get() |
61 |
| - return client.gather(return_list) |
62 |
| - |
63 |
| -import time |
64 |
| -from threading import Thread |
65 | 12 | from queue import Queue
|
| 13 | +from threading import Thread |
66 | 14 |
|
67 |
| -def main_1(): |
68 |
| - queue = Queue() |
69 |
| - key_generator_thread = Thread(target = key_generator(queue)) |
70 |
| - frame_consumer_serial_thread = Thread(target = frame_consumer_serial, args = (queue,)) |
71 |
| - |
72 |
| - start_time = time.time() |
73 |
| - key_generator_thread.start() |
74 |
| - frame_consumer_serial_thread.start() |
75 |
| - key_generator_thread.join() |
76 |
| - frame_consumer_serial_thread.join() |
77 |
| - finish_time = time.time() |
78 |
| - print(f"serial_time_taken = {finish_time - start_time}") |
79 |
| - |
80 |
| -#main_1() |
81 |
| - |
82 |
| -def main_2(): |
83 |
| - queue = Queue() |
84 |
| - key_generator_thread = Thread(target = key_generator, args = (queue,)) |
85 |
| - frame_consumer_serial_thread = Thread(target = frame_consumer_parallel, args = (queue,)) |
86 |
| - |
87 |
| - start_time = time.time() |
88 |
| - key_generator_thread.start() |
89 |
| - frame_consumer_serial_thread.start() |
90 |
| - key_generator_thread.join() |
91 |
| - frame_consumer_serial_thread.join() |
92 |
| - finish_time = time.time() |
93 |
| - print(f"serial_time_taken = {finish_time - start_time}") |
94 | 15 |
|
95 |
| -#main_2() |
| 16 | +class BinOp(): |
| 17 | + |
| 18 | + def __init__(self, |
| 19 | + hdf5_file, |
| 20 | + key_datasets, |
| 21 | + dataset_1, |
| 22 | + dataset_2): |
| 23 | + |
| 24 | + self.hdf5_file = hdf5_file |
| 25 | + self.key_datasets = key_datasets |
| 26 | + self.dataset_1 = dataset_1 |
| 27 | + self.dataset_2 = dataset_2 |
| 28 | + self.queue = Queue() |
| 29 | + |
| 30 | + |
| 31 | + def _key_generator(self): |
| 32 | + with h5py.File(self.hdf5_file, "r", swmr = True) as f: |
| 33 | + kf = Follower(f, self.key_datasets, timeout = 0.1) |
| 34 | + for key in kf: |
| 35 | + self.queue.put(key) |
| 36 | + self.queue.put("End") |
| 37 | + |
| 38 | + def _frame_consumer(self, fn): |
| 39 | + return_list = [] |
| 40 | + key = self.queue.get() |
| 41 | + while key != 'End': |
| 42 | + return_list.append(fn(key)) |
| 43 | + key = self.queue.get() |
| 44 | + return return_list |
| 45 | + |
| 46 | + |
| 47 | + def run(self, fn): |
| 48 | + queue = Queue() |
96 | 49 |
|
97 | 50 |
|
98 | 51 |
|
0 commit comments