@@ -30,7 +30,6 @@ def __init__(self,
30
30
self .zh_frontend = Frontend (
31
31
phone_vocab_path = phone_vocab_path , tone_vocab_path = tone_vocab_path )
32
32
self .en_frontend = English (phone_vocab_path = phone_vocab_path )
33
- self .SENTENCE_SPLITOR = re .compile (r'([:、,;。?!,;?!][”’]?)' )
34
33
self .sp_id = self .zh_frontend .vocab_phones ["sp" ]
35
34
self .sp_id_tensor = paddle .to_tensor ([self .sp_id ])
36
35
@@ -47,114 +46,59 @@ def is_alphabet(self, char):
47
46
else :
48
47
return False
49
48
50
- def is_number (self , char ):
51
- if char >= '\u0030 ' and char <= '\u0039 ' :
52
- return True
53
- else :
54
- return False
55
-
56
49
def is_other(self, char):
    """Report whether ``char`` is neither Chinese nor alphabetic.

    Returns True when both ``self.is_chinese(char)`` and
    ``self.is_alphabet(char)`` are falsy, otherwise False. Digits are no
    longer special-cased, so they count as "other".
    """
    return not (self.is_chinese(char) or self.is_alphabet(char))
62
54
63
- def _replace (self , text : str ) -> str :
64
- new_text = text
65
-
66
- # get "." indexs
67
- point_indexs = []
68
- index = - 1
69
- for i in range (text .count ("." )):
70
- index = text .find ("." , index + 1 , len (text ))
71
- point_indexs .append (index )
72
-
73
- # replace
74
- if len (point_indexs ) != 0 :
75
- for index in point_indexs :
76
- ch = text [index - 1 ]
77
- if self .is_alphabet (ch ) or ch == " " :
78
- new_text = new_text [:index ] + "。" + new_text [index + 1 :]
79
-
80
- return new_text
81
-
82
- def _split (self , text : str ) -> List [str ]:
83
- text = re .sub (r'[《》【】<=>{}()()#&@“”^_|…\\]' , '' , text )
84
- # 替换英文句子的句号 "." --> "。" 用于后续分句
85
- text = self ._replace (text )
86
- text = self .SENTENCE_SPLITOR .sub (r'\1\n' , text )
87
- text = text .strip ()
88
- sentences = [sentence .strip () for sentence in re .split (r'\n+' , text )]
89
- return sentences
90
-
91
- def _distinguish (self , text : str ) -> List [str ]:
55
+
56
+ def get_segment (self , text : str ) -> List [str ]:
92
57
# sentence --> [ch_part, en_part, ch_part, ...]
93
58
94
59
segments = []
95
60
types = []
96
-
97
61
flag = 0
98
62
temp_seg = ""
99
63
temp_lang = ""
100
64
101
65
# Determine the type of each character. type: zh (Chinese), en (alphabet) or other (everything else).
102
66
for ch in text :
103
- if ch == "." :
104
- types .append ("point" )
105
- elif self .is_chinese (ch ):
67
+ if self .is_chinese (ch ):
106
68
types .append ("zh" )
107
69
elif self .is_alphabet (ch ):
108
70
types .append ("en" )
109
- elif ch == " " :
110
- types .append ("blank" )
111
- elif self .is_number (ch ):
112
- types .append ("num" )
113
71
else :
114
- types .append ("unk " )
72
+ types .append ("other " )
115
73
116
74
assert len (types ) == len (text )
117
75
118
76
for i in range (len (types )):
119
77
120
78
# find the first char of the seg
121
79
if flag == 0 :
122
- # 首个字符是中文,英文或者数字
123
- if types [i ] == "zh" or types [i ] == "en" or types [i ] == "num" :
124
- temp_seg += text [i ]
125
- temp_lang = types [i ]
126
- flag = 1
80
+ temp_seg += text [i ]
81
+ temp_lang = types [i ]
82
+ flag = 1
127
83
128
84
else :
129
- # 数字和小数点均与前面的字符合并,类型属于前面一个字符的类型
130
- if types [i ] == temp_lang or types [i ] == "num" or types [
131
- i ] == "point" :
132
- temp_seg += text [i ]
133
-
134
- # 数字与后面的任意字符都拼接
135
- elif temp_lang == "num" :
136
- temp_seg += text [i ]
137
- if types [i ] == "zh" or types [i ] == "en" :
85
+ if temp_lang == "other" :
86
+ if types [i ] == temp_lang :
87
+ temp_seg += text [i ]
88
+ else :
89
+ temp_seg += text [i ]
138
90
temp_lang = types [i ]
139
91
140
- # 如果是空格则与前面字符拼接
141
- elif types [i ] == "blank" :
142
- temp_seg += text [i ]
143
-
144
- elif types [i ] == "unk" :
145
- pass
146
-
147
92
else :
148
- segments .append ((temp_seg , temp_lang ))
149
-
150
- if types [i ] == "zh" or types [i ] == "en" :
93
+ if types [i ] == temp_lang :
94
+ temp_seg += text [i ]
95
+ elif types [i ] == "other" :
96
+ temp_seg += text [i ]
97
+ else :
98
+ segments .append ((temp_seg , temp_lang ))
151
99
temp_seg = text [i ]
152
100
temp_lang = types [i ]
153
101
flag = 1
154
- else :
155
- flag = 0
156
- temp_seg = ""
157
- temp_lang = ""
158
102
159
103
segments .append ((temp_seg , temp_lang ))
160
104
@@ -167,35 +111,30 @@ def get_input_ids(self,
167
111
add_sp : bool = True ,
168
112
to_tensor : bool = True ) -> Dict [str , List [paddle .Tensor ]]:
169
113
170
- sentences = self ._split (sentence )
114
+ segments = self .get_segment (sentence )
115
+
171
116
phones_list = []
172
117
result = {}
173
- for text in sentences :
174
- phones_seg = []
175
- segments = self ._distinguish (text )
176
- for seg in segments :
177
- content = seg [0 ]
178
- lang = seg [1 ]
179
- if content != '' :
180
- if lang == "en" :
181
- input_ids = self .en_frontend .get_input_ids (
182
- content , merge_sentences = True , to_tensor = to_tensor )
183
- else :
184
- input_ids = self .zh_frontend .get_input_ids (
185
- content ,
186
- merge_sentences = True ,
187
- get_tone_ids = get_tone_ids ,
188
- to_tensor = to_tensor )
189
-
190
- phones_seg .append (input_ids ["phone_ids" ][0 ])
191
- if add_sp :
192
- phones_seg .append (self .sp_id_tensor )
193
-
194
- if phones_seg == []:
195
- phones_seg .append (self .sp_id_tensor )
196
- phones = paddle .concat (phones_seg )
197
- phones_list .append (phones )
198
118
119
+ for seg in segments :
120
+ content = seg [0 ]
121
+ lang = seg [1 ]
122
+ if content != '' :
123
+ if lang == "en" :
124
+ input_ids = self .en_frontend .get_input_ids (
125
+ content , merge_sentences = False , to_tensor = to_tensor )
126
+ else :
127
+ input_ids = self .zh_frontend .get_input_ids (
128
+ content ,
129
+ merge_sentences = False ,
130
+ get_tone_ids = get_tone_ids ,
131
+ to_tensor = to_tensor )
132
+ if add_sp :
133
+ input_ids ["phone_ids" ][- 1 ] = paddle .concat ([input_ids ["phone_ids" ][- 1 ], self .sp_id_tensor ])
134
+
135
+ for phones in input_ids ["phone_ids" ]:
136
+ phones_list .append (phones )
137
+
199
138
if merge_sentences :
200
139
merge_list = paddle .concat (phones_list )
201
140
# rm the last 'sp' to avoid the noise at the end
0 commit comments