40 changes: 31 additions & 9 deletions python/paddle/text/datasets/imdb.py
```diff
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import collections
 import re
 import string
 import tarfile
+from typing import TYPE_CHECKING, Literal

 import numpy as np

 from paddle.dataset.common import _check_exists_and_download
 from paddle.io import Dataset

+if TYPE_CHECKING:
+    from re import Pattern
+
+    import numpy.typing as npt
+
+_ImdbDataSetMode = Literal["train", "test"]
 __all__ = []

 URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
```
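The new `_ImdbDataSetMode` alias restricts `mode` to the two valid split names at type-check time. A minimal sketch of the effect, assuming a checker such as mypy or pyright (runtime behavior is unchanged):

```python
from typing import Literal

_ImdbDataSetMode = Literal["train", "test"]

def load_split(mode: _ImdbDataSetMode = "train") -> None:
    # Hypothetical helper, for illustration only.
    print(f"loading {mode} split")

load_split("test")     # accepted
load_split("valid")    # flagged by the type checker; still runs at runtime
```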
```diff
@@ -33,12 +41,12 @@ class Imdb(Dataset):
     Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.

     Args:
-        data_file(str): path to data tar file, can be set None if
-            :attr:`download` is True. Default None
+        data_file(str|None): path to data tar file, can be set None if
+            :attr:`download` is True. Default None.
         mode(str): 'train' 'test' mode. Default 'train'.
         cutoff(int): cutoff number for building word dictionary. Default 150.
         download(bool): whether to download dataset automatically if
-            :attr:`data_file` is not set. Default True
+            :attr:`data_file` is not set. Default True.

     Returns:
         Dataset: instance of IMDB dataset
```
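For context, a usage sketch matching the documented arguments; the import path is assumed from this file's location under `paddle/text/datasets`, and the download path needs network access:

```python
from paddle.text import Imdb

# With data_file=None and download=True, the archive is fetched and
# cached on first use.
imdb = Imdb(mode='test', cutoff=150, download=True)

doc, label = imdb[0]  # token-id array and a single 0/1 label
print(len(imdb), doc.shape, label.shape)
```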
```diff
@@ -82,7 +90,19 @@ class Imdb(Dataset):

     """

-    def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
+    data_file: str | None
+    mode: _ImdbDataSetMode
+    word_idx: dict[str, int]
+    docs: list
+    labels: list
+
+    def __init__(
+        self,
+        data_file: str | None = None,
+        mode: _ImdbDataSetMode = 'train',
+        cutoff: int = 150,
+        download: bool = True,
+    ) -> None:
         assert mode.lower() in [
             'train',
             'test',
```
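The class-level annotations declare instance attributes for type checkers without creating class-level defaults; assignment still happens in `__init__`. A minimal sketch of what that buys, using an illustrative class:

```python
from __future__ import annotations

class Example:
    word_idx: dict[str, int]  # declared for checkers; no class default exists

    def __init__(self) -> None:
        self.word_idx = {'<unk>': 0}

e = Example()
e.word_idx['movie'] = 1   # checker knows keys are str and values are int
# e.word_idx[0] = 'x'     # would be flagged: key must be str
```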
```diff
@@ -104,7 +124,7 @@ def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
         # read dataset into memory
         self._load_anno()

-    def _build_work_dict(self, cutoff):
+    def _build_work_dict(self, cutoff: int) -> dict[str, int]:
         word_freq = collections.defaultdict(int)
         pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
         for doc in self._tokenize(pattern):
```
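`_build_work_dict` counts word frequencies over every tokenized document and keeps only sufficiently frequent words; the exact comparison sits in the folded part of the diff, so the strict `>` below is an assumption. A self-contained sketch of the rule:

```python
from __future__ import annotations

import collections

def build_word_dict(docs: list[list[str]], cutoff: int) -> dict[str, int]:
    word_freq: dict[str, int] = collections.defaultdict(int)
    for doc in docs:
        for word in doc:
            word_freq[word] += 1
    # Words at or below the cutoff are dropped and later map to <unk>.
    words = [w for w, freq in word_freq.items() if freq > cutoff]
    word_idx = {w: i for i, w in enumerate(words)}
    word_idx['<unk>'] = len(words)
    return word_idx

print(build_word_dict([['good'], ['good'], ['bad']], cutoff=1))
# {'good': 0, '<unk>': 1}
```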
```diff
@@ -120,7 +140,7 @@ def _build_work_dict(self, cutoff):
         word_idx['<unk>'] = len(words)
         return word_idx

-    def _tokenize(self, pattern):
+    def _tokenize(self, pattern: Pattern[str]) -> list[list[str]]:
         data = []
         with tarfile.open(self.data_file) as tarf:
             tf = tarf.next()
```
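The `Pattern[str]` annotation records that `_tokenize` takes a compiled regex rather than a raw string, matching the per-split patterns `_load_anno` builds. A small sketch:

```python
from __future__ import annotations

import re

# re.compile on a str pattern yields re.Pattern[str], which is what the
# new _tokenize signature documents.
pattern: re.Pattern[str] = re.compile(r"aclImdb/train/pos/.*\.txt$")
assert pattern.match("aclImdb/train/pos/0_9.txt") is not None
assert pattern.match("aclImdb/train/neg/0_3.txt") is None
```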
```diff
@@ -139,7 +159,7 @@ def _tokenize(self, pattern):

         return data

-    def _load_anno(self):
+    def _load_anno(self) -> None:
         pos_pattern = re.compile(fr"aclImdb/{self.mode}/pos/.*\.txt$")
         neg_pattern = re.compile(fr"aclImdb/{self.mode}/neg/.*\.txt$")

```
```diff
@@ -154,8 +174,10 @@ def _load_anno(self):
                 self.docs.append([self.word_idx.get(w, UNK) for w in doc])
                 self.labels.append(1)

-    def __getitem__(self, idx):
+    def __getitem__(
+        self, idx: int
+    ) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
         return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.docs)
```
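With `__getitem__` and `__len__` annotated, downstream iteration is fully typed. A hedged consumption sketch; note the per-document arrays vary in length, so stacking them into batches would need padding first:

```python
import numpy as np

from paddle.text import Imdb

imdb = Imdb(mode='train')  # downloads the archive on first use

for i in range(len(imdb)):
    doc, label = imdb[i]  # an NDArray[np.int_] pair, per the annotation
    assert np.issubdtype(doc.dtype, np.integer)
    assert label.shape == (1,)  # single 0/1 sentiment label
```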