2 changes: 2 additions & 0 deletions model_zoo/ernie-1.0/preprocess/README.md
@@ -175,6 +175,8 @@ common config:
Logging interval: progress is reported every `interval` text lines / docs processed.
--workers WORKERS Number of worker processes to launch
Number of worker processes used to convert text to token ids.
--max_repeated_len Max length of repeated chars to keep
Maximum number of consecutive repeated characters to keep.
```
Running the conversion script below produces the processed pretraining data: token ids in `baike_sample.bin` and the document index in `baike_sample.idx`.

19 changes: 19 additions & 0 deletions model_zoo/ernie-1.0/preprocess/create_pretraining_data.py
@@ -104,6 +104,9 @@ def get_args():
group.add_argument("--log_interval", type=int, default=100, help="Interval between progress updates")
group.add_argument("--workers", type=int, default=1, help="Number of worker processes to launch")
group.add_argument("--max_doc_num", type=int, default=sys.maxsize, help="Number of worker processes to launch")
group.add_argument(
"--max_repeated_len", type=int, default=100, help="The maximum length of the repeated characters to keep"
)

args = parser.parse_args()
return args
@@ -278,8 +281,24 @@ def process(text):

Converter.process = process

def remove_repeated_chars(text, max_repeated_len=100):
    """
    Collapses over-long runs of a repeated character in the given text:
    any run longer than max_repeated_len is replaced with a single
    occurrence of that character, while shorter runs are kept as-is.

    Args:
        text (str): The input text to clean.
        max_repeated_len (int, optional): The maximum run length to keep.
            Defaults to 100.

    Returns:
        str: The text with over-long repeated runs collapsed.
    """
    pattern = r"(.)\1{" + str(max_repeated_len) + ",}"
    return re.sub(pattern, r"\1", text)

def encode(self, json_line):
text = json.loads(json_line)[self.args.json_key]
text = Converter.remove_repeated_chars(text, self.args.max_repeated_len)
doc_ids = []
for sentence in Converter.splitter.tokenize(text):
sentence_ids = Converter.process(sentence.strip())
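For reference, here is a self-contained sketch of what the new `remove_repeated_chars` helper does, reusing the regex from the diff; the sample strings below are illustrative only:

```python
import re


def remove_repeated_chars(text, max_repeated_len=100):
    # A run of a single character survives untouched as long as its length
    # stays at or below max_repeated_len; anything longer collapses to a
    # single occurrence of that character.
    pattern = r"(.)\1{" + str(max_repeated_len) + ",}"
    return re.sub(pattern, r"\1", text)


print(remove_repeated_chars("good" + "-" * 200 + "bye"))  # 'good-bye'
print(remove_repeated_chars("wait" + "." * 50 + "ok"))    # unchanged: run of 50 <= 100
print(remove_repeated_chars("aaaa", max_repeated_len=3))  # 'a': run of 4 > 3
```

Note that an over-long run is reduced to a single character rather than truncated to `max_repeated_len` characters, which is worth keeping in mind when tuning the flag for noisy web text.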