1313# limitations under the License.
1414from __future__ import annotations
1515
16- import collections
1716import os
17+ from typing import TYPE_CHECKING , Any , Literal , NamedTuple
18+
19+ from typing_extensions import TypeAlias
1820
1921from paddle .dataset .common import DATA_HOME
2022from paddle .utils import download
2123
2224from .dataset import AudioClassificationDataset
2325
26+ if TYPE_CHECKING :
27+ _ModeLiteral : TypeAlias = Literal [
28+ 'train' ,
29+ 'dev' ,
30+ ]
31+ _FeatTypeLiteral : TypeAlias = Literal [
32+ 'raw' ,
33+ 'melspectrogram' ,
34+ 'mfcc' ,
35+ 'logmelspectrogram' ,
36+ 'spectrogram' ,
37+ ]
38+
39+
2440__all__ = []
2541
2642
@@ -80,12 +96,12 @@ class ESC50(AudioClassificationDataset):
8096
8197 """
8298
83- archive = {
99+ archive : dict [ str , str ] = {
84100 'url' : 'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip' ,
85101 'md5' : '7771e4b9d86d0945acce719c7a59305a' ,
86102 }
87103
88- label_list = [
104+ label_list : list [ str ] = [
89105 # Animals
90106 'Dog' ,
91107 'Rooster' ,
@@ -142,21 +158,26 @@ class ESC50(AudioClassificationDataset):
142158 'Fireworks' ,
143159 'Hand saw' ,
144160 ]
145- meta = os .path .join ('ESC-50-master' , 'meta' , 'esc50.csv' )
146- meta_info = collections .namedtuple (
147- 'META_INFO' ,
148- ('filename' , 'fold' , 'target' , 'category' , 'esc10' , 'src_file' , 'take' ),
149- )
150- audio_path = os .path .join ('ESC-50-master' , 'audio' )
161+ meta : str = os .path .join ('ESC-50-master' , 'meta' , 'esc50.csv' )
162+ audio_path : str = os .path .join ('ESC-50-master' , 'audio' )
163+
164+ class meta_info (NamedTuple ):
165+ filename : str
166+ fold : str
167+ target : str
168+ category : str
169+ esc10 : str
170+ src_file : str
171+ take : str
151172
152173 def __init__ (
153174 self ,
154- mode : str = 'train' ,
175+ mode : _ModeLiteral = 'train' ,
155176 split : int = 1 ,
156- feat_type : str = 'raw' ,
157- archive = None ,
158- ** kwargs ,
159- ):
177+ feat_type : _FeatTypeLiteral = 'raw' ,
178+ archive : dict [ str , str ] | None = None ,
179+ ** kwargs : Any ,
180+ ) -> None :
160181 assert split in range (
161182 1 , 6
162183 ), f'The selected split should be integer, and 1 <= split <= 5, but got { split } '
@@ -167,14 +188,16 @@ def __init__(
167188 files = files , labels = labels , feat_type = feat_type , ** kwargs
168189 )
169190
170- def _get_meta_info (self ) -> list [collections . namedtuple ]:
191+ def _get_meta_info (self ) -> list [meta_info ]:
171192 ret = []
172193 with open (os .path .join (DATA_HOME , self .meta ), 'r' ) as rf :
173194 for line in rf .readlines ()[1 :]:
174195 ret .append (self .meta_info (* line .strip ().split (',' )))
175196 return ret
176197
177- def _get_data (self , mode : str , split : int ) -> tuple [list [str ], list [int ]]:
198+ def _get_data (
199+ self , mode : _ModeLiteral , split : int
200+ ) -> tuple [list [str ], list [int ]]:
178201 if not os .path .isdir (
179202 os .path .join (DATA_HOME , self .audio_path )
180203 ) or not os .path .isfile (os .path .join (DATA_HOME , self .meta )):
0 commit comments