Skip to content

Commit a47c97b

Browse files
enkileeSigureMo
authored andcommitted
[Typing][B-61] Add type annotations for python/paddle/text/datasets/imdb.py (PaddlePaddle#66037)
--------- Co-authored-by: Nyakku Shigure <[email protected]>
1 parent b820224 commit a47c97b

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

python/paddle/text/datasets/imdb.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,25 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import collections
1617
import re
1718
import string
1819
import tarfile
20+
from typing import TYPE_CHECKING, Literal
1921

2022
import numpy as np
2123

2224
from paddle.dataset.common import _check_exists_and_download
2325
from paddle.io import Dataset
2426

27+
if TYPE_CHECKING:
28+
from re import Pattern
29+
30+
import numpy.typing as npt
31+
32+
_ImdbDataSetMode = Literal["train", "test"]
2533
__all__ = []
2634

2735
URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
@@ -33,12 +41,12 @@ class Imdb(Dataset):
3341
Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.
3442
3543
Args:
36-
data_file(str): path to data tar file, can be set None if
37-
:attr:`download` is True. Default None
44+
data_file(str|None): path to data tar file, can be set None if
45+
:attr:`download` is True. Default None.
3846
mode(str): 'train' 'test' mode. Default 'train'.
3947
cutoff(int): cutoff number for building word dictionary. Default 150.
4048
download(bool): whether to download dataset automatically if
41-
:attr:`data_file` is not set. Default True
49+
:attr:`data_file` is not set. Default True.
4250
4351
Returns:
4452
Dataset: instance of IMDB dataset
@@ -82,7 +90,19 @@ class Imdb(Dataset):
8290
8391
"""
8492

85-
def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
93+
data_file: str | None
94+
mode: _ImdbDataSetMode
95+
word_idx: dict[str, int]
96+
docs: list
97+
labels: list
98+
99+
def __init__(
100+
self,
101+
data_file: str | None = None,
102+
mode: _ImdbDataSetMode = 'train',
103+
cutoff: int = 150,
104+
download: bool = True,
105+
) -> None:
86106
assert mode.lower() in [
87107
'train',
88108
'test',
@@ -104,7 +124,7 @@ def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
104124
# read dataset into memory
105125
self._load_anno()
106126

107-
def _build_work_dict(self, cutoff):
127+
def _build_work_dict(self, cutoff: int) -> dict[str, int]:
108128
word_freq = collections.defaultdict(int)
109129
pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
110130
for doc in self._tokenize(pattern):
@@ -120,7 +140,7 @@ def _build_work_dict(self, cutoff):
120140
word_idx['<unk>'] = len(words)
121141
return word_idx
122142

123-
def _tokenize(self, pattern):
143+
def _tokenize(self, pattern: Pattern[str]) -> list[list[str]]:
124144
data = []
125145
with tarfile.open(self.data_file) as tarf:
126146
tf = tarf.next()
@@ -139,7 +159,7 @@ def _tokenize(self, pattern):
139159

140160
return data
141161

142-
def _load_anno(self):
162+
def _load_anno(self) -> None:
143163
pos_pattern = re.compile(fr"aclImdb/{self.mode}/pos/.*\.txt$")
144164
neg_pattern = re.compile(fr"aclImdb/{self.mode}/neg/.*\.txt$")
145165

@@ -154,8 +174,10 @@ def _load_anno(self):
154174
self.docs.append([self.word_idx.get(w, UNK) for w in doc])
155175
self.labels.append(1)
156176

157-
def __getitem__(self, idx):
177+
def __getitem__(
178+
self, idx: int
179+
) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
158180
return (np.array(self.docs[idx]), np.array([self.labels[idx]]))
159181

160-
def __len__(self):
182+
def __len__(self) -> int:
161183
return len(self.docs)

0 commit comments

Comments
 (0)