Skip to content

Commit daef051

Browse files
committed
fix: resolve mypy errors and minor logic and name changes
- Add type annotations and fix missing arguments / return statements flagged by mypy.
- Make minor logic and naming changes in list_of_ranges.py.
- Make a very minor change to fix a mypy error in test_user_agent.py.
1 parent efeb194 commit daef051

File tree

4 files changed

+29
-35
lines changed

4 files changed

+29
-35
lines changed

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,11 @@ def __init__(
354354
e.g. S3ReaderConstructor.sequential() or S3ReaderConstructor.range_based()
355355
"""
356356
super().__init__(path)
357-
self.fs = S3FileSystem(region, s3client_config=s3client_config, reader_constructor=reader_constructor) # type: ignore
357+
self.fs: S3FileSystem = S3FileSystem( # type: ignore[assignment]
358+
region,
359+
s3client_config=s3client_config,
360+
reader_constructor=reader_constructor,
361+
)
358362
self.path = self.fs.init_path(path)
359363
self.sync_files = False
360364

@@ -376,7 +380,7 @@ def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
376380
# Inject ranges if using DCP list-of-ranges reader constructor
377381
if isinstance(self.fs._reader_constructor, DCPListOfRangesConstructor):
378382
# Calculate ranges per file
379-
per_file_ranges = {}
383+
per_file_ranges: Dict[str, List[RangeRequest]] = {}
380384
for read_item in plan.items:
381385
item_md = self.storage_data[read_item.storage_index]
382386
path = item_md.relative_path
@@ -391,6 +395,9 @@ def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
391395

392396
# Sort items in plan based on their offset in checkpoints shards
393397
plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
398+
logger.info(
399+
f"Sorted {len(plan.items)} items in load plan based on offset in checkpoint shards"
400+
)
394401
return plan
395402

396403

s3torchconnector/src/s3torchconnector/s3reader/list_of_ranges.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33

4-
import os
54
import logging
65
from dataclasses import dataclass
76
from typing import List, Optional, Callable, Union, Dict
@@ -53,8 +52,7 @@ def __init__(
5352
# Calculate range groups using coalescing logic
5453
self._range_groups = self._calculate_range_groups(ranges, max_gap_size)
5554

56-
# Pre-create all readers and prefetch immediately
57-
# TODO - judge if this is beneficial or not.
55+
# Pre-create all readers
5856
self._group_readers: Dict[int, SequentialS3Reader] = {}
5957
for i, group in enumerate(self._range_groups):
6058
reader = SequentialS3Reader(
@@ -65,16 +63,11 @@ def __init__(
6563
start_offset=group.start,
6664
end_offset=group.end,
6765
)
66+
# TODO - judge if this is beneficial or not.
6867
reader.prefetch() # Batch prefetch all ranges
6968
self._group_readers[i] = reader
7069

71-
# Pre-calculate request-to-reader mapping
72-
self._request_to_reader: Dict[int, int] = {}
73-
for i, group in enumerate(self._range_groups):
74-
for request in group.requests:
75-
self._request_to_reader[request.start] = i
76-
77-
self._current_position = 0
70+
self._position: int = 0
7871

7972
@property
8073
def bucket(self) -> str:
@@ -92,6 +85,7 @@ def _calculate_range_groups(
9285
if not ranges:
9386
return []
9487

88+
# TODO: could be pre-sorted in prepare_local_plan for dcp.load
9589
sorted_ranges = sorted(ranges, key=lambda r: r.start)
9690
groups = []
9791
current_group = [sorted_ranges[0]]
@@ -117,48 +111,41 @@ def _create_range_group(self, ranges: List[RangeRequest]) -> RangeGroup:
117111
group_end = max(r.end for r in ranges)
118112
return RangeGroup(start=group_start, end=group_end, requests=ranges)
119113

120-
def get_reader_for_request(
121-
self, request_start: int
122-
) -> Optional[SequentialS3Reader]:
123-
"""O(1) lookup using pre-calculated mapping."""
124-
reader_idx = self._request_to_reader.get(request_start)
125-
return self._group_readers.get(reader_idx) if reader_idx is not None else None
126-
127114
def _find_reader_for_offset(self, offset: int) -> Optional[SequentialS3Reader]:
128115
"""Find reader that contains the given offset."""
129-
# TODO: improve logic using binary search
130-
for reader in self._group_readers.values():
131-
if reader._start_offset <= offset < reader._end_offset:
132-
return reader
133-
elif reader._start_offset > offset:
134-
break # Early termination since readers are ordered
116+
for i, group in enumerate(self._range_groups):
117+
if group.start <= offset < group.end:
118+
self._current_reader_index = i
119+
return self._group_readers[i]
120+
if group.start > offset: # TODO handle this case properly by raising errors
121+
break
135122
return None
136123

137124
def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
138-
self._current_position = offset
125+
self._position = offset
139126
reader = self._find_reader_for_offset(offset)
140127
if not reader:
141-
return self._current_position
142-
reader.seek(offset, whence)
128+
return self._position
129+
return reader.seek(offset, whence)
143130

144131
def read(self, size: Optional[int] = None) -> bytes:
145-
reader = self._find_reader_for_offset(self._current_position)
132+
reader = self._find_reader_for_offset(self._position)
146133
if not reader:
147134
return b""
148135
data = reader.read(size)
149-
self._current_position += len(data)
136+
self._position += len(data)
150137
return data
151138

152139
def readinto(self, buf) -> int:
153-
reader = self._find_reader_for_offset(self._current_position)
140+
reader = self._find_reader_for_offset(self._position)
154141
if not reader:
155142
return 0
156143
bytes_read = reader.readinto(buf)
157-
self._current_position += bytes_read
144+
self._position += bytes_read
158145
return bytes_read
159146

160147
def tell(self) -> int:
161-
return self._current_position
148+
return self._position
162149

163150
def close(self) -> None:
164151
for reader in self._group_readers.values():

s3torchconnector/src/s3torchconnector/s3reader/sequential.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def prefetch(self) -> None:
107107
if self._start_offset is not None or self._end_offset is not None:
108108
self._stream = self._get_stream(self._start_offset, self._end_offset)
109109
else:
110-
self._stream = self._get_stream()
110+
self._stream = self._get_stream(None, None)
111111

112112
def readinto(self, buf) -> int:
113113
"""Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf.

s3torchconnector/tst/unit/test_user_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ def test_default_user_agent_creation():
3939

4040
@pytest.mark.parametrize("invalid_comment", [0, "string"])
4141
def test_invalid_comments_argument(invalid_comment):
42-
with pytest.raises(ValueError, match="Argument comments must be a List\[str\]"):
42+
with pytest.raises(ValueError, match=r"Argument comments must be a List\[str\]"):
4343
UserAgent(invalid_comment)

0 commit comments

Comments
 (0)