Commit 463691b

dnnanuti and IsaevIlya committed
Add implementation of readinto to decrease amount of copy operations (#200)
Co-authored-by: Ilya Isaev <[email protected]>
1 parent 0b3b63a commit 463691b

File tree

4 files changed: +146 −4 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 ### New features
 * Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.
+* Improve the performance of `s3reader` when utilized with `pytorch.load` by incorporating support for the `readinto` method.
 
 
 ## v1.2.2 (March 22, 2024)
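
For context on the changelog entry: torch.load can fill a tensor's pre-allocated storage through a file-like object's readinto, so exposing readinto on the reader avoids building an intermediate bytes object per read. A minimal sketch of the loading pattern this speeds up (region, bucket, and key are placeholders):

import torch
from s3torchconnector import S3Checkpoint

checkpoint = S3Checkpoint(region="us-east-1")  # placeholder region
with checkpoint.reader("s3://my-bucket/general_checkpoint.pt") as reader:
    # torch.load can now call reader.readinto(...) to fill buffers in place.
    state = torch.load(reader)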

s3torchconnector/src/s3torchconnector/s3reader.py

Lines changed: 31 additions & 1 deletion
@@ -53,6 +53,34 @@ def prefetch(self) -> None:
         if self._stream is None:
             self._stream = self._get_stream()
 
+    def readinto(self, buf) -> int:
+        """Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf.
+        Return the number of bytes read. If no bytes are available, zero is returned.
+
+        Args:
+            buf : writable bytes-like object
+
+        Returns:
+            int : number of bytes read, or zero if no bytes are available
+        """
+        buf_size = len(buf)
+        if self._position_at_end() or buf_size == 0:
+            # If no bytes are available or no place to write data, zero should be returned
+            return 0
+
+        self.prefetch()
+        assert self._stream is not None
+
+        cur_pos = self._position
+        # preload enough bytes in buffer
+        self.seek(buf_size, SEEK_CUR)
+        # restore position before starting to write into buf
+        self._buffer.seek(cur_pos)
+        size = self._buffer.readinto(buf)
+        self._position = self._buffer.tell()
+
+        return size
+
     def read(self, size: Optional[int] = None) -> bytes:
         """Read up to size bytes from the object and return them.
 
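
The method avoids an extra copy by delegating to io.BytesIO.readinto on the internal buffer once enough bytes have been prefetched. A standalone sketch of that primitive (the data and buffer size here are arbitrary placeholders):

import io

buffer = io.BytesIO(b"0123456789")  # stands in for the reader's internal self._buffer
buf = memoryview(bytearray(4))      # caller-supplied, pre-allocated target

n = buffer.readinto(buf)  # copies directly into buf, no intermediate bytes object
assert n == 4
assert bytes(buf) == b"0123"
assert buffer.tell() == 4  # position advances, mirroring the self._position bookkeeping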
@@ -82,7 +110,9 @@ def read(self, size: Optional[int] = None) -> bytes:
         if size is None or size < 0:
             # Special case read() all to use O(n) algorithm
             self._buffer.seek(0, SEEK_END)
-            self._buffer.write(b"".join(self._stream))
+            for batch in self._stream:
+                self._buffer.write(batch)
+
             # Once we've emptied the buffer, we'll always be at EOF!
             self._size = self._buffer.tell()
         else:
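
The reworked read-all path matters because b"".join(self._stream) first materializes the entire object as one intermediate bytes value, which BytesIO.write then copies a second time; iterating writes each chunk into the buffer exactly once. A self-contained illustration of the two patterns (chunk count and size are arbitrary):

import io

chunks = [b"x" * 1024] * 4  # stand-in for the S3 stream's byte chunks

# Before: join allocates a full-size intermediate bytes object, then write copies it again.
before = io.BytesIO()
before.write(b"".join(chunks))

# After: each chunk is copied into the buffer once, with no full-size intermediate.
after = io.BytesIO()
for batch in chunks:
    after.write(batch)

assert before.getvalue() == after.getvalue()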

s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py

Lines changed: 7 additions & 2 deletions
@@ -2,13 +2,18 @@
 # // SPDX-License-Identifier: BSD
 
 import torch
+import pytest
 
 from s3torchconnector import S3Checkpoint
 from models.net import Net
 
 
-def test_general_checkpointing(checkpoint_directory):
-    tensor = torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
+@pytest.mark.parametrize(
+    "tensor_dimensions",
+    [[3, 2], [10, 1024, 1024]],
+)
+def test_general_checkpointing(checkpoint_directory, tensor_dimensions):
+    tensor = torch.rand(tensor_dimensions)
     checkpoint_name = "general_checkpoint.pt"
     checkpoint = S3Checkpoint(region=checkpoint_directory.region)
     s3_uri = f"{checkpoint_directory.s3_uri}/{checkpoint_name}"
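
The added [10, 1024, 1024] case makes the saved checkpoint roughly 40 MB of float32 data, large enough that loading it spans many stream chunks and so exercises the new readinto path end to end. A local sketch of the same round trip, using an in-memory buffer instead of S3:

import io
import torch

tensor = torch.rand([10, 1024, 1024])  # ~40 MB of float32 data

buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)

# The e2e test performs this same save/load round trip through
# S3Checkpoint's writer and reader instead of a local buffer.
loaded = torch.load(buffer)
assert torch.equal(tensor, loaded)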

s3torchconnector/tst/unit/test_s3reader.py

Lines changed: 107 additions & 1 deletion
@@ -21,7 +21,6 @@
 
 log = logging.getLogger(__name__)
 
-
 TEST_BUCKET = "test-bucket"
 TEST_KEY = "test-key"
 MOCK_OBJECT_INFO = Mock(ObjectInfo)

@@ -182,14 +181,17 @@ def test_over_read(stream: List[bytes], overread: int):
 def test_seeks_end():
     s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([]))
     s3reader._size = 10
+    buf = memoryview(bytearray(10))
 
     assert s3reader.seek(0, SEEK_END) == 10
     assert s3reader.tell() == 10
     assert s3reader.read() == b""
+    assert s3reader.readinto(buf) == 0
 
     assert s3reader.seek(0, SEEK_CUR) == 10
     assert s3reader.tell() == 10
     assert s3reader.read() == b""
+    assert s3reader.readinto(buf) == 0
 
 
 def test_not_writable():

@@ -301,3 +303,107 @@ def test_s3reader_writes_size_after_read_all_explicit(stream: List[bytes]):
     assert s3reader.read(1) == b""
     # Once we've read past the end, we know how big the file is
     assert s3reader._size == total_length
+
+
+@given(
+    lists(binary(min_size=20, max_size=30), min_size=0, max_size=2),
+    integers(min_value=0, max_value=10),
+)
+def test_s3reader_readinto_buffer_smaller_than_chunks(
+    stream: List[bytes], buf_size: int
+):
+    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
+    assert s3reader._size is None
+    total_length = sum(map(len, stream))
+    buf = memoryview(bytearray(buf_size))
+    # We're able to read all the available data or the data that can be accommodated in buf
+    if buf_size > 0 and total_length > 0:
+        assert s3reader.readinto(buf) == buf_size
+        assert s3reader.tell() == buf_size
+        # We haven't reached the end yet
+        assert s3reader._size is None
+        # confirm that read data is the same as in source
+        assert buf[:buf_size] == (b"".join(stream))[:buf_size]
+    else:
+        assert s3reader.readinto(buf) == 0
+        assert s3reader.tell() == 0
+
+
+@given(
+    lists(binary(min_size=20, max_size=30), min_size=2, max_size=3),
+    integers(min_value=30, max_value=40),
+)
+def test_s3reader_readinto_buffer_bigger_than_chunks(
+    stream: List[bytes], buf_size: int
+):
+    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
+    assert s3reader._size is None
+    buf = memoryview(bytearray(buf_size))
+    # We're able to read the data that can be accommodated in buf
+    assert s3reader.readinto(buf) == buf_size
+    assert s3reader.tell() == buf_size
+    all_data = b"".join(stream)
+    # confirm that read data is the same as in source
+    assert buf == all_data[:buf_size]
+
+
+@given(
+    lists(binary(min_size=20, max_size=30), min_size=1, max_size=3),
+    integers(min_value=100, max_value=100),
+)
+def test_s3reader_readinto_buffer_bigger_than_whole_object(
+    stream: List[bytes], buf_size: int
+):
+    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
+    assert s3reader._size is None
+    total_length = sum(map(len, stream))
+    buf = memoryview(bytearray(buf_size))
+    # We're able to read all the available data
+    assert s3reader.readinto(buf) == total_length
+    assert s3reader.tell() == total_length
+    all_data = b"".join(stream)
+    # confirm that read data is the same as in source
+    assert buf[:total_length] == all_data
+    assert s3reader._size == total_length
+
+
+@given(
+    lists(binary(min_size=2, max_size=12), min_size=1, max_size=5),
+    integers(min_value=3, max_value=10),
+    integers(min_value=0, max_value=1),
+)
+def test_s3reader_mixing_readinto_and_read(
+    stream: List[bytes], buf_size: int, flip: int
+):
+    position = 0
+    loops_count = 20
+    all_data = b"".join(stream)
+    total_length = len(all_data)
+    buf = memoryview(bytearray(buf_size))
+    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
+    for i in range(0, loops_count):
+        if position >= total_length:
+            break
+
+        if (i + flip) % 2 == 0:
+            result = s3reader.read(buf_size)
+            # confirm that read data is the same as in source
+            if position + buf_size < total_length:
+                assert result[:buf_size] == all_data[position : position + buf_size]
+            else:
+                read_bytes = total_length - position
+                assert result[:read_bytes] == all_data[position:total_length]
+            position += buf_size
+        else:
+            read_bytes = s3reader.readinto(buf)
+            # confirm that read data is the same as in source
+            assert buf[position:read_bytes] == all_data[position:read_bytes]
+            position += read_bytes
+
+        if position > total_length:
+            # we read all the data, it is time to stop
+            assert s3reader.tell() == total_length
+            break
+        else:
+            # confirm that position is as expected
+            assert s3reader.tell() == position