Skip to content

Commit b72f352

Browse files
authored
[DATA] Remove repeated chars during preprocessing (#7739)
* add remove_repeated_chars * update max_repeated_len
1 parent 6f45e95 commit b72f352

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

model_zoo/ernie-1.0/preprocess/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ common config:
175175
打印日志间隔,interval表示处理 文本行数/doc数的 间隔。
176176
--workers WORKERS Number of worker processes to launch
177177
处理文本id化的进程个数。
178+
--max_repeated_len Max length of repeated chars to keep
179+
最大保留重复的字符个数。
178180
```
179181
通过下面脚本转化,我们可以得到处理好的预训练数据,token ids:`baike_sample.bin`, 文章索引信息`baike_sample.idx`.
180182

model_zoo/ernie-1.0/preprocess/create_pretraining_data.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def get_args():
104104
group.add_argument("--log_interval", type=int, default=100, help="Interval between progress updates")
105105
group.add_argument("--workers", type=int, default=1, help="Number of worker processes to launch")
106106
group.add_argument("--max_doc_num", type=int, default=sys.maxsize, help="Number of worker processes to launch")
107+
group.add_argument(
108+
"--max_repeated_len", type=int, default=100, help="The maximum length of the repeated characters to keep"
109+
)
107110

108111
args = parser.parse_args()
109112
return args
@@ -278,8 +281,24 @@ def process(text):
278281

279282
Converter.process = process
280283

284+
def remove_repeated_chars(text, max_repeated_len=100):
285+
"""
286+
Removes repeated characters from the given text, where the length of
287+
the repeated characters is greater than or equal to the specified length.
288+
289+
Args:
290+
text (str): The input text from which to remove repeated characters.
291+
length (int, optional): The minimum length of the repeated characters. Defaults to 15.
292+
293+
Returns:
294+
str: The modified text with the repeated characters removed.
295+
"""
296+
pattern = r"(.)\1{" + str(max_repeated_len) + ",}"
297+
return re.sub(pattern, r"\1", text)
298+
281299
def encode(self, json_line):
282300
text = json.loads(json_line)[self.args.json_key]
301+
text = Converter.remove_repeated_chars(text, self.args.max_repeated_len)
283302
doc_ids = []
284303
for sentence in Converter.splitter.tokenize(text):
285304
sentence_ids = Converter.process(sentence.strip())

0 commit comments

Comments
 (0)