
Commit 5a1c4ac

save extra special tokens (#9837)
1 parent 3900428 commit 5a1c4ac

File tree

1 file changed: +20 -2 lines changed


paddlenlp/transformers/tokenizer_utils_base.py

Lines changed: 20 additions & 2 deletions
@@ -967,6 +967,11 @@ def add_tokens(
 
         return self._add_tokens(new_tokens, special_tokens=special_tokens)
 
+    @classmethod
+    def _add_extra_special_tokens(cls, extra_sp_token: Union[str, AddedToken]):
+        if extra_sp_token not in cls.SPECIAL_TOKENS_ATTRIBUTES:
+            cls.SPECIAL_TOKENS_ATTRIBUTES.append(extra_sp_token)
+
     def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
         raise NotImplementedError
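A minimal sketch of what the new classmethod enables, using a simplified stand-in class (`ExampleTokenizer` and its attribute list are illustrative, not PaddleNLP's real class): registering an attribute name appends it to the class-level `SPECIAL_TOKENS_ATTRIBUTES` list, so every later walk over that list (such as `special_tokens_map` below) also visits the extra token.

from typing import Union


class AddedToken:  # simplified stand-in for paddlenlp's AddedToken
    def __init__(self, content: str):
        self.content = content


class ExampleTokenizer:
    # class-level registry of recognized special-token attribute names
    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "pad_token"]

    @classmethod
    def _add_extra_special_tokens(cls, extra_sp_token: Union[str, AddedToken]):
        # same logic as the commit: only append names not already registered
        if extra_sp_token not in cls.SPECIAL_TOKENS_ATTRIBUTES:
            cls.SPECIAL_TOKENS_ATTRIBUTES.append(extra_sp_token)


ExampleTokenizer._add_extra_special_tokens("image_token")
print(ExampleTokenizer.SPECIAL_TOKENS_ATTRIBUTES)
# ['bos_token', 'eos_token', 'pad_token', 'image_token']

Because the list is a class attribute, registration is shared by all instances of that tokenizer class, and the membership check keeps repeated loads from appending duplicates.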

@@ -1213,7 +1218,13 @@ def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
         """
         set_attr = {}
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-            attr_value = getattr(self, "_" + attr)
+            try:
+                attr_value = getattr(self, "_" + attr)
+            except:
+                try:
+                    attr_value = getattr(self, attr)
+                except:
+                    continue
             if attr_value:
                 set_attr[attr] = (
                     type(attr_value)(str(attr_value_sub) for attr_value_sub in attr_value)
@@ -1233,7 +1244,13 @@ def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]:
         """
         set_attr = {}
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-            attr_value = getattr(self, "_" + attr)
+            try:
+                attr_value = getattr(self, "_" + attr)
+            except:
+                try:
+                    attr_value = getattr(self, attr)
+                except:
+                    continue
             if attr_value:
                 set_attr[attr] = attr_value
         return set_attr
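The fallback order in both properties can be read as a small helper (a sketch only; the commit inlines this with bare `except:` clauses, and `AttributeError` is the case those clauses effectively guard against):

def _lookup_special_token(tokenizer, attr):
    # built-in special tokens live behind a private "_"-prefixed attribute
    try:
        return getattr(tokenizer, "_" + attr)
    except AttributeError:
        # extra tokens restored at load time are plain attributes set via setattr
        try:
            return getattr(tokenizer, attr)
        except AttributeError:
            # neither form exists: the property skips this name entirely
            return None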
@@ -1744,6 +1761,7 @@ def convert_added_tokens(obj):
             elif isinstance(value, list):
                 value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
             setattr(tokenizer, key, value)
+            cls._add_extra_special_tokens(key)
 
         # Add supplementary tokens.
         special_tokens = tokenizer.all_special_tokens
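Taken together, the change makes non-standard special-token keys survive a load/save round trip: any key restored from the saved tokenizer config is registered, so `special_tokens_map` reports it and `save_pretrained` writes it back out. A hypothetical round trip (the model name and the `image_token` key below are placeholders for illustration, not guaranteed entries):

from paddlenlp.transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("org/model-with-extra-tokens")  # placeholder name
# If the saved config carried an extra key such as "image_token", loading
# now registers it via cls._add_extra_special_tokens(key), so it shows up:
print(tok.special_tokens_map)   # includes "image_token" alongside the built-ins
tok.save_pretrained("./saved")  # the extra token is written back out, not dropped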
