Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/paddle/_typing/device_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@
XPUPlace,
)

PlaceLike: TypeAlias = Union[
_Place: TypeAlias = Union[
"CPUPlace",
"CUDAPlace",
"CUDAPinnedPlace",
"IPUPlace",
"CustomPlace",
"XPUPlace",
]

PlaceLike: TypeAlias = Union[
_Place,
str, # some string like "cpu", "gpu:0", etc.
]
3 changes: 3 additions & 0 deletions python/paddle/io/dataloader/dataloader_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def _index_sampler(self):
def __iter__(self):
return self

def __next__(self):
raise NotImplementedError('Should implement `__next__` for a iterator')

def __len__(self):
return len(self._batch_sampler)

Expand Down
8 changes: 4 additions & 4 deletions python/paddle/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@

from paddle import Tensor
from paddle._typing import PlaceLike
from paddle._typing.device_like import _Place
from paddle.io.dataloader.dataloader_iter import _DataLoaderIterBase
from paddle.io.dataloader.dataset import Dataset

from .dataloader.dataloader_iter import _DataLoaderIterBase

_K = TypeVar('_K')
_V = TypeVar('_V')

Expand Down Expand Up @@ -440,7 +440,7 @@ class DataLoader:
worker_init_fn: Callable[[int], None]
dataset: Dataset
feed_list: Sequence[Tensor] | None
places: Sequence[PlaceLike] | None
places: list[_Place]
num_workers: int
dataset_kind: _DatasetKind
use_shared_memory: bool
Expand All @@ -449,7 +449,7 @@ def __init__(
self,
dataset: Dataset,
feed_list: Sequence[Tensor] | None = None,
places: Sequence[PlaceLike] | None = None,
places: PlaceLike | Sequence[PlaceLike] | None = None,
return_list: bool = True,
batch_sampler: BatchSampler | None = None,
batch_size: int = 1,
Expand Down