1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from __future__ import annotations
1415
1516import collections
1617import re
1718import string
1819import tarfile
20+ from typing import TYPE_CHECKING , Literal
1921
2022import numpy as np
2123
2224from paddle .dataset .common import _check_exists_and_download
2325from paddle .io import Dataset
2426
27+ if TYPE_CHECKING :
28+ from re import Pattern
29+
30+ import numpy .typing as npt
31+
32+ _ImdbDataSetMode = Literal ["train" , "test" ]
2533__all__ = []
2634
2735URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
@@ -33,12 +41,12 @@ class Imdb(Dataset):
3341 Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.
3442
3543 Args:
        data_file(str|None): path to data tar file, can be set to None if
            :attr:`download` is True. Default None.
        mode(str): 'train' or 'test' mode. Default 'train'.
3947 cutoff(int): cutoff number for building word dictionary. Default 150.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True.
4250
4351 Returns:
4452 Dataset: instance of IMDB dataset
@@ -82,7 +90,19 @@ class Imdb(Dataset):
8290
8391 """
8492
85- def __init__ (self , data_file = None , mode = 'train' , cutoff = 150 , download = True ):
93+ data_file : str | None
94+ mode : _ImdbDataSetMode
95+ word_idx : dict [str , int ]
96+ docs : list
97+ labels : list
98+
99+ def __init__ (
100+ self ,
101+ data_file : str | None = None ,
102+ mode : _ImdbDataSetMode = 'train' ,
103+ cutoff : int = 150 ,
104+ download : bool = True ,
105+ ) -> None :
86106 assert mode .lower () in [
87107 'train' ,
88108 'test' ,
@@ -104,7 +124,7 @@ def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
104124 # read dataset into memory
105125 self ._load_anno ()
106126
107- def _build_work_dict (self , cutoff ) :
127+ def _build_work_dict (self , cutoff : int ) -> dict [ str , int ] :
108128 word_freq = collections .defaultdict (int )
109129 pattern = re .compile (r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$" )
110130 for doc in self ._tokenize (pattern ):
@@ -120,7 +140,7 @@ def _build_work_dict(self, cutoff):
120140 word_idx ['<unk>' ] = len (words )
121141 return word_idx
122142
123- def _tokenize (self , pattern ) :
143+ def _tokenize (self , pattern : Pattern [ str ]) -> list [ list [ str ]] :
124144 data = []
125145 with tarfile .open (self .data_file ) as tarf :
126146 tf = tarf .next ()
@@ -139,7 +159,7 @@ def _tokenize(self, pattern):
139159
140160 return data
141161
142- def _load_anno (self ):
162+ def _load_anno (self ) -> None :
143163 pos_pattern = re .compile (fr"aclImdb/{ self .mode } /pos/.*\.txt$" )
144164 neg_pattern = re .compile (fr"aclImdb/{ self .mode } /neg/.*\.txt$" )
145165
@@ -154,8 +174,10 @@ def _load_anno(self):
154174 self .docs .append ([self .word_idx .get (w , UNK ) for w in doc ])
155175 self .labels .append (1 )
156176
157- def __getitem__ (self , idx ):
177+ def __getitem__ (
178+ self , idx : int
179+ ) -> tuple [npt .NDArray [np .int_ ], npt .NDArray [np .int_ ]]:
158180 return (np .array (self .docs [idx ]), np .array ([self .labels [idx ]]))
159181
160- def __len__ (self ):
182+ def __len__ (self ) -> int :
161183 return len (self .docs )
0 commit comments