1- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import math
1516import os
1617import re
17- import math
18- import random
1918from typing import Iterable
2019
21- import numpy as np
22- import paddle
2320from paddle .dataset .common import md5file
2421from paddle .utils .download import get_path_from_url
2522
23+ from ..data import JiebaTokenizer , Vocab
2624from ..utils .env import DATA_HOME
27- from ..data import Vocab , JiebaTokenizer
2825
2926
3027class BaseAugment (object ):
@@ -44,7 +41,7 @@ class BaseAugment(object):
4441 Maximum number of augmented words in sequences.
4542 """
4643
47- def __init__ (self , create_n , aug_n = None , aug_percent = 0.02 , aug_min = 1 , aug_max = 10 ):
44+ def __init__ (self , create_n = 1 , aug_n = None , aug_percent = 0.1 , aug_min = 1 , aug_max = 10 , vocab = "vocab" ):
4845 self ._DATA = {
4946 "stop_words" : (
5047 "stopwords.txt" ,
@@ -56,24 +53,49 @@ def __init__(self, create_n, aug_n=None, aug_percent=0.02, aug_min=1, aug_max=10
5653 "25c2d41aec5a6d328a65c1995d4e4c2e" ,
5754 "https://bj.bcebos.com/paddlenlp/data/baidu_encyclopedia_w2v_vocab.json" ,
5855 ),
56+ "test_vocab" : (
57+ "test_vocab.json" ,
58+ "1d2fce1c80a4a0ec2e90a136f339ab88" ,
59+ "https://bj.bcebos.com/paddlenlp/data/test_vocab.json" ,
60+ ),
5961 "word_synonym" : (
6062 "word_synonym.json" ,
6163 "aaa9f864b4af4123bce4bf138a5bfa0d" ,
6264 "https://bj.bcebos.com/paddlenlp/data/word_synonym.json" ,
6365 ),
66+ "word_embedding" : (
67+ "word_embedding.json" ,
68+ "534aa4ad274def4deff585cefd8ead32" ,
69+ "https://bj.bcebos.com/paddlenlp/data/word_embedding.json" ,
70+ ),
6471 "word_homonym" : (
6572 "word_homonym.json" ,
6673 "a578c04201a697e738f6a1ad555787d5" ,
6774 "https://bj.bcebos.com/paddlenlp/data/word_homonym.json" ,
6875 ),
76+ "char_homonym" : (
77+ "char_homonym.json" ,
78+ "dd98d5d5d32a3d3dd45c8f7ca503c7df" ,
79+ "https://bj.bcebos.com/paddlenlp/data/char_homonym.json" ,
80+ ),
81+ "char_antonym" : (
82+ "char_antonym.json" ,
83+ "f892f5dce06f17d19949ebcbe0ed52b7" ,
84+ "https://bj.bcebos.com/paddlenlp/data/char_antonym.json" ,
85+ ),
86+ "word_antonym" : (
87+ "word_antonym.json" ,
88+ "cbea11fa99fbe9d07e8185750b37e84a" ,
89+ "https://bj.bcebos.com/paddlenlp/data/word_antonym.json" ,
90+ ),
6991 }
7092 self .stop_words = self ._get_data ("stop_words" )
7193 self .aug_n = aug_n
7294 self .aug_percent = aug_percent
7395 self .aug_min = aug_min
7496 self .aug_max = aug_max
7597 self .create_n = create_n
76- self .vocab = Vocab .from_json (self ._load_file (" vocab" ))
98+ self .vocab = Vocab .from_json (self ._load_file (vocab ))
7799 self .tokenizer = JiebaTokenizer (self .vocab )
78100 self .loop = 5
79101
@@ -150,7 +172,7 @@ def augment(self, sequences, num_thread=1):
150172 # Single Thread
151173 if num_thread == 1 :
152174 if isinstance (sequences , str ):
153- return self ._augment (sequences )
175+ return [ self ._augment (sequences )]
154176 else :
155177 output = []
156178 for sequence in sequences :
@@ -161,3 +183,59 @@ def augment(self, sequences, num_thread=1):
161183
    def _augment(self, sequence):
        """Produce augmented versions of a single ``sequence``.

        Abstract hook: the base class defines no augmentation strategy, so
        every concrete subclass must override this method. ``augment``
        (presumably the public entry point — confirm against the full class)
        dispatches each input sequence here.
        """
        raise NotImplementedError
186+
187+
class FileAugment(object):
    """
    File data augmentation

    Args:
        strategies (List):
            List of augmentation strategies.
    """

    def __init__(self, strategies):
        self.strategies = strategies

    def augment(self, input_file, output_file="aug.txt", separator=None, separator_id=0):
        """Apply every strategy to the examples in ``input_file``.

        Args:
            input_file (str): Path of the input file, one example per line.
            output_file (str): Path the augmented lines are written to.
                Pass a falsy value (e.g. ``None`` or ``""``) to skip writing.
            separator (str): Optional field separator of each line. When
                given, only the field at ``separator_id`` is augmented and the
                remaining fields are copied into every generated line.
            separator_id (int): Index of the field to augment.

        Returns:
            list[str]: All augmented output lines, in strategy order.
        """
        input_sequences = self.file_read(input_file)

        if separator:
            sequences = [line.split(separator)[separator_id] for line in input_sequences]
        else:
            sequences = input_sequences

        output_sequences = []
        for strategy in self.strategies:
            aug_sequences = strategy.augment(sequences)
            if separator:
                # Re-insert each augmented field into a copy of its source
                # line so the untouched fields are preserved verbatim.
                for aug_sequence, input_sequence in zip(aug_sequences, input_sequences):
                    input_items = input_sequence.split(separator)
                    for aug_item in aug_sequence:
                        input_items[separator_id] = aug_item
                        output_sequences.append(separator.join(input_items))
            else:
                for aug_sequence in aug_sequences:
                    output_sequences += aug_sequence

        if output_file:
            self.file_write(output_sequences, output_file)

        return output_sequences

    def file_read(self, input_file):
        """Read ``input_file`` and return its lines with surrounding whitespace stripped."""
        # ``with`` closes the file on exit; no explicit close() needed.
        with open(input_file, "r", encoding="utf-8") as f:
            return [line.strip() for line in f]

    def file_write(self, output_sequences, output_file):
        """Write ``output_sequences`` to ``output_file``, one per line."""
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(sequence + "\n" for sequence in output_sequences)
0 commit comments