Skip to content

Commit 4306af2

Browse files
authored
[Typing][C-127,C-128] Add type annotations for python/paddle/audio/datasets/{esc50,tess}.py (#67067)
1 parent 5f3d472 commit 4306af2

File tree

2 files changed: 58 additions & 29 deletions

2 files changed

+58
-29
lines changed

python/paddle/audio/datasets/esc50.py

Lines changed: 39 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,30 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
import collections
1716
import os
17+
from typing import TYPE_CHECKING, Any, Literal, NamedTuple
18+
19+
from typing_extensions import TypeAlias
1820

1921
from paddle.dataset.common import DATA_HOME
2022
from paddle.utils import download
2123

2224
from .dataset import AudioClassificationDataset
2325

26+
if TYPE_CHECKING:
27+
_ModeLiteral: TypeAlias = Literal[
28+
'train',
29+
'dev',
30+
]
31+
_FeatTypeLiteral: TypeAlias = Literal[
32+
'raw',
33+
'melspectrogram',
34+
'mfcc',
35+
'logmelspectrogram',
36+
'spectrogram',
37+
]
38+
39+
2440
__all__ = []
2541

2642

@@ -80,12 +96,12 @@ class ESC50(AudioClassificationDataset):
8096
8197
"""
8298

83-
archive = {
99+
archive: dict[str, str] = {
84100
'url': 'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip',
85101
'md5': '7771e4b9d86d0945acce719c7a59305a',
86102
}
87103

88-
label_list = [
104+
label_list: list[str] = [
89105
# Animals
90106
'Dog',
91107
'Rooster',
@@ -142,21 +158,26 @@ class ESC50(AudioClassificationDataset):
142158
'Fireworks',
143159
'Hand saw',
144160
]
145-
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
146-
meta_info = collections.namedtuple(
147-
'META_INFO',
148-
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'),
149-
)
150-
audio_path = os.path.join('ESC-50-master', 'audio')
161+
meta: str = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
162+
audio_path: str = os.path.join('ESC-50-master', 'audio')
163+
164+
class meta_info(NamedTuple):
165+
filename: str
166+
fold: str
167+
target: str
168+
category: str
169+
esc10: str
170+
src_file: str
171+
take: str
151172

152173
def __init__(
153174
self,
154-
mode: str = 'train',
175+
mode: _ModeLiteral = 'train',
155176
split: int = 1,
156-
feat_type: str = 'raw',
157-
archive=None,
158-
**kwargs,
159-
):
177+
feat_type: _FeatTypeLiteral = 'raw',
178+
archive: dict[str, str] | None = None,
179+
**kwargs: Any,
180+
) -> None:
160181
assert split in range(
161182
1, 6
162183
), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
@@ -167,14 +188,16 @@ def __init__(
167188
files=files, labels=labels, feat_type=feat_type, **kwargs
168189
)
169190

170-
def _get_meta_info(self) -> list[collections.namedtuple]:
191+
def _get_meta_info(self) -> list[meta_info]:
171192
ret = []
172193
with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
173194
for line in rf.readlines()[1:]:
174195
ret.append(self.meta_info(*line.strip().split(',')))
175196
return ret
176197

177-
def _get_data(self, mode: str, split: int) -> tuple[list[str], list[int]]:
198+
def _get_data(
199+
self, mode: _ModeLiteral, split: int
200+
) -> tuple[list[str], list[int]]:
178201
if not os.path.isdir(
179202
os.path.join(DATA_HOME, self.audio_path)
180203
) or not os.path.isfile(os.path.join(DATA_HOME, self.meta)):

python/paddle/audio/datasets/tess.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
import collections
1716
import os
17+
from typing import TYPE_CHECKING, Any, NamedTuple
1818

1919
from paddle.dataset.common import DATA_HOME
2020
from paddle.utils import download
2121

2222
from .dataset import AudioClassificationDataset
2323

24+
if TYPE_CHECKING:
25+
from .esc50 import _FeatTypeLiteral, _ModeLiteral
26+
2427
__all__ = []
2528

2629

@@ -71,12 +74,12 @@ class TESS(AudioClassificationDataset):
7174
... # [feature_dim, num_frames] , label_id
7275
"""
7376

74-
archive = {
77+
archive: dict[str, str] = {
7578
'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
7679
'md5': '1465311b24d1de704c4c63e4ccc470c7',
7780
}
7881

79-
label_list = [
82+
label_list: list[str] = [
8083
'angry',
8184
'disgust',
8285
'fear',
@@ -85,20 +88,23 @@ class TESS(AudioClassificationDataset):
8588
'ps', # pleasant surprise
8689
'sad',
8790
]
88-
meta_info = collections.namedtuple(
89-
'META_INFO', ('speaker', 'word', 'emotion')
90-
)
91-
audio_path = 'TESS_Toronto_emotional_speech_set'
91+
92+
audio_path: str = 'TESS_Toronto_emotional_speech_set'
93+
94+
class meta_info(NamedTuple):
95+
speaker: str
96+
word: str
97+
emotion: str
9298

9399
def __init__(
94100
self,
95-
mode: str = 'train',
101+
mode: _ModeLiteral = 'train',
96102
n_folds: int = 5,
97103
split: int = 1,
98-
feat_type: str = 'raw',
99-
archive=None,
100-
**kwargs,
101-
):
104+
feat_type: _FeatTypeLiteral = 'raw',
105+
archive: dict[str, str] | None = None,
106+
**kwargs: Any,
107+
) -> None:
102108
assert isinstance(n_folds, int) and (
103109
n_folds >= 1
104110
), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
@@ -112,7 +118,7 @@ def __init__(
112118
files=files, labels=labels, feat_type=feat_type, **kwargs
113119
)
114120

115-
def _get_meta_info(self, files) -> list[collections.namedtuple]:
121+
def _get_meta_info(self, files) -> list[meta_info]:
116122
ret = []
117123
for file in files:
118124
basename_without_extend = os.path.basename(file)[:-4]

0 commit comments

Comments
 (0)