Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit bb5fcdb

Browse files
committed
typehints for Vocab __init__
1 parent 0275d50 commit bb5fcdb

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

src/gluonnlp/vocab/vocab.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,19 @@
2020
# pylint: disable=consider-iterating-dictionary
2121

2222
"""Vocabulary."""
23-
from __future__ import absolute_import
24-
from __future__ import print_function
25-
2623
__all__ = ['Vocab']
2724

2825
import collections
2926
import json
27+
import typing
3028
import uuid
3129
import warnings
3230

3331
from mxnet import nd
3432

35-
from ..data.utils import DefaultLookupDict, count_tokens
3633
from .. import _constants as C
3734
from .. import embedding as emb
35+
from ..data.utils import DefaultLookupDict, Counter, count_tokens
3836

3937
UNK_IDX = 0
4038

@@ -44,38 +42,38 @@ class Vocab(object):
4442
4543
Parameters
4644
----------
47-
counter : Counter or None, default None
45+
counter
4846
Counts text token frequencies in the text data. Its keys will be indexed according to
4947
frequency thresholds such as `max_size` and `min_freq`. Keys of `counter`,
5048
`unknown_token`, and values of `reserved_tokens` must be of the same hashable type.
5149
Examples: str, int, and tuple.
52-
max_size : None or int, default None
50+
max_size
5351
The maximum possible number of the most frequent tokens in the keys of `counter` that can be
5452
indexed. Note that this argument does not count any token from `reserved_tokens`. Suppose
5553
that there are different keys of `counter` whose frequencies are the same; if indexing all of
5654
them will exceed this argument value, such keys will be indexed one by one according to
5755
their __cmp__() order until the frequency threshold is met. If this argument is None or
5856
larger than its largest possible value restricted by `counter` and `reserved_tokens`, this
5957
argument has no effect.
60-
min_freq : int, default 1
58+
min_freq
6159
The minimum frequency required for a token in the keys of `counter` to be indexed.
62-
unknown_token : hashable object or None, default '<unk>'
60+
unknown_token
6361
The representation for any unknown token. If `unknown_token` is not
6462
`None`, looking up any token that is not part of the vocabulary and
6563
thus considered unknown will return the index of `unknown_token`. If
6664
None, looking up an unknown token will result in `KeyError`.
67-
padding_token : hashable object or None, default '<pad>'
65+
padding_token
6866
The representation of the special padding token.
69-
bos_token : hashable object or None, default '<bos>'
67+
bos_token
7068
The representation of the special beginning-of-sequence token.
71-
eos_token : hashable object or None, default '<eos>'
69+
eos_token
7270
The representation of the special end-of-sequence token.
73-
reserved_tokens : list of hashable objects or None, default None
71+
reserved_tokens
7472
A list specifying additional tokens to be added to the vocabulary.
7573
`reserved_tokens` must not contain the value of `unknown_token` or
7674
duplicate tokens. It must also not contain special tokens specified via
7775
keyword arguments.
78-
token_to_idx : dict mapping tokens (hashable objects) to int or None, default None
76+
token_to_idx
7977
If not `None`, specifies the indices of tokens to be used by the
8078
vocabulary. Each token in `token_to_idx` must be part of the Vocab
8179
and each index can only be associated with a single token.
@@ -175,9 +173,14 @@ class Vocab(object):
175173
176174
"""
177175

178-
def __init__(self, counter=None, max_size=None, min_freq=1, unknown_token=C.UNK_TOKEN,
179-
padding_token=C.PAD_TOKEN, bos_token=C.BOS_TOKEN, eos_token=C.EOS_TOKEN,
180-
reserved_tokens=None, token_to_idx=None, **kwargs):
176+
def __init__(self, counter: typing.Optional[Counter] = None,
177+
max_size: typing.Optional[int] = None, min_freq: int = 1,
178+
unknown_token: typing.Optional[typing.Hashable] = C.UNK_TOKEN,
179+
padding_token: typing.Optional[typing.Hashable] = C.PAD_TOKEN,
180+
bos_token: typing.Optional[typing.Hashable] = C.BOS_TOKEN,
181+
eos_token: typing.Optional[typing.Hashable] = C.EOS_TOKEN,
182+
reserved_tokens: typing.Optional[typing.List[typing.Hashable]] = None,
183+
token_to_idx: typing.Optional[typing.Dict[typing.Hashable, int]] = None, **kwargs):
181184

182185
# Sanity checks.
183186
assert min_freq > 0, '`min_freq` must be set to a positive value.'

0 commit comments

Comments
 (0)