Skip to content

Commit 1b106e2

Browse files
committed
update
1 parent 5e874ee commit 1b106e2

File tree

5 files changed

+67
-30
lines changed

5 files changed

+67
-30
lines changed

paddlenlp/dataaug/base_augment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class BaseAugment(object):
4141
Maximum number of augmented words in sequences.
4242
"""
4343

44-
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10):
44+
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10, vocab="vocab"):
4545
self._DATA = {
4646
"stop_words": (
4747
"stopwords.txt",
@@ -53,6 +53,11 @@ def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=1
5353
"25c2d41aec5a6d328a65c1995d4e4c2e",
5454
"https://bj.bcebos.com/paddlenlp/data/baidu_encyclopedia_w2v_vocab.json",
5555
),
56+
"test_vocab": (
57+
"test_vocab.json",
58+
"1d2fce1c80a4a0ec2e90a136f339ab88",
59+
"https://bj.bcebos.com/paddlenlp/data/test_vocab.json",
60+
),
5661
"word_synonym": (
5762
"word_synonym.json",
5863
"aaa9f864b4af4123bce4bf138a5bfa0d",
@@ -90,7 +95,7 @@ def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=1
9095
self.aug_min = aug_min
9196
self.aug_max = aug_max
9297
self.create_n = create_n
93-
self.vocab = Vocab.from_json(self._load_file("vocab"))
98+
self.vocab = Vocab.from_json(self._load_file(vocab))
9499
self.tokenizer = JiebaTokenizer(self.vocab)
95100
self.loop = 5
96101

paddlenlp/dataaug/char.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ def __init__(
6363
aug_min=1,
6464
aug_max=10,
6565
model_name="ernie-1.0-large-zh-cw",
66+
vocab="vocab",
6667
):
67-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
68+
super().__init__(
69+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
70+
)
6871

6972
self.custom_file_path = custom_file_path
7073
self.delete_file_path = delete_file_path
@@ -275,8 +278,11 @@ def __init__(
275278
aug_min=1,
276279
aug_max=10,
277280
model_name="ernie-1.0-large-zh-cw",
281+
vocab="vocab",
278282
):
279-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
283+
super().__init__(
284+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
285+
)
280286

281287
self.custom_file_path = custom_file_path
282288
self.delete_file_path = delete_file_path
@@ -457,8 +463,10 @@ class CharSwap(BaseAugment):
457463
Maximum number of augmented characters in sequences.
458464
"""
459465

460-
def __init__(self, create_n=1, aug_n=None, aug_percent=None, aug_min=1, aug_max=10):
461-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=0.1, aug_min=aug_min, aug_max=aug_max)
466+
def __init__(self, create_n=1, aug_n=None, aug_percent=None, aug_min=1, aug_max=10, vocab="vocab"):
467+
super().__init__(
468+
create_n=create_n, aug_n=aug_n, aug_percent=0.1, aug_min=aug_min, aug_max=aug_max, vocab=vocab
469+
)
462470

463471
def _augment(self, sequence):
464472

@@ -521,8 +529,10 @@ class CharDelete(BaseAugment):
521529
Maximum number of augmented characters in sequences.
522530
"""
523531

524-
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10):
525-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
532+
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10, vocab="vocab"):
533+
super().__init__(
534+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
535+
)
526536

527537
def _augment(self, sequence):
528538

paddlenlp/dataaug/word.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,11 @@ def __init__(
7070
tf_idf=False,
7171
tf_idf_file=None,
7272
model_name="ernie-1.0-large-zh-cw",
73+
vocab="vocab",
7374
):
74-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
75+
super().__init__(
76+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
77+
)
7578

7679
self.custom_file_path = custom_file_path
7780
self.delete_file_path = delete_file_path
@@ -341,8 +344,11 @@ def __init__(
341344
aug_min=1,
342345
aug_max=10,
343346
model_name="ernie-1.0-large-zh-cw",
347+
vocab="vocab",
344348
):
345-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
349+
super().__init__(
350+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
351+
)
346352

347353
self.custom_file_path = custom_file_path
348354
self.delete_file_path = delete_file_path
@@ -524,8 +530,10 @@ class WordSwap(BaseAugment):
524530
Maximum number of augmented words in sequences.
525531
"""
526532

527-
def __init__(self, create_n=1, aug_n=None, aug_percent=None, aug_min=1, aug_max=10):
528-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=0.1, aug_min=aug_min, aug_max=aug_max)
533+
def __init__(self, create_n=1, aug_n=None, aug_percent=None, aug_min=1, aug_max=10, vocab="vocab"):
534+
super().__init__(
535+
create_n=create_n, aug_n=aug_n, aug_percent=0.1, aug_min=aug_min, aug_max=aug_max, vocab=vocab
536+
)
529537

530538
def _augment(self, sequence):
531539

@@ -588,8 +596,10 @@ class WordDelete(BaseAugment):
588596
Maximum number of augmented words in sequences.
589597
"""
590598

591-
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10):
592-
super().__init__(create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max)
599+
def __init__(self, create_n=1, aug_n=None, aug_percent=0.1, aug_min=1, aug_max=10, vocab="vocab"):
600+
super().__init__(
601+
create_n=create_n, aug_n=aug_n, aug_percent=aug_percent, aug_min=aug_min, aug_max=aug_max, vocab=vocab
602+
)
593603

594604
def _augment(self, sequence):
595605

tests/dataaug/test_char_aug.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,18 @@ def set_random_seed(self, seed):
5757
def test_char_substitute(self, create_n):
5858
for t in self.types:
5959
if t == "mlm":
60-
aug = CharSubstitute("mlm", create_n=create_n, model_name="__internal_testing__/ernie")
60+
aug = CharSubstitute(
61+
"mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab"
62+
)
6163
augmented = aug.augment(self.sequences)
6264
self.assertEqual(len(self.sequences), len(augmented))
6365
continue
6466
elif t == "custom":
65-
aug = CharSubstitute("custom", create_n=create_n, custom_file_path=self.custom_file_path)
67+
aug = CharSubstitute(
68+
"custom", create_n=create_n, custom_file_path=self.custom_file_path, vocab="test_vocab"
69+
)
6670
else:
67-
aug = CharSubstitute(t, create_n=create_n)
71+
aug = CharSubstitute(t, create_n=create_n, vocab="test_vocab")
6872

6973
augmented = aug.augment(self.sequences)
7074
self.assertEqual(len(self.sequences), len(augmented))
@@ -75,14 +79,16 @@ def test_char_substitute(self, create_n):
7579
def test_char_insert(self, create_n):
7680
for t in self.types:
7781
if t == "mlm":
78-
aug = CharInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie")
82+
aug = CharInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab")
7983
augmented = aug.augment(self.sequences)
8084
self.assertEqual(len(self.sequences), len(augmented))
8185
continue
8286
elif t == "custom":
83-
aug = CharInsert("custom", create_n=create_n, custom_file_path=self.custom_file_path)
87+
aug = CharInsert(
88+
"custom", create_n=create_n, custom_file_path=self.custom_file_path, vocab="test_vocab"
89+
)
8490
else:
85-
aug = CharInsert(t, create_n=create_n)
91+
aug = CharInsert(t, create_n=create_n, vocab="test_vocab")
8692

8793
augmented = aug.augment(self.sequences)
8894
self.assertEqual(len(self.sequences), len(augmented))
@@ -91,15 +97,15 @@ def test_char_insert(self, create_n):
9197

9298
@parameterized.expand([(1,)])
9399
def test_char_delete(self, create_n):
94-
aug = CharDelete(create_n=create_n)
100+
aug = CharDelete(create_n=create_n, vocab="test_vocab")
95101
augmented = aug.augment(self.sequences)
96102
self.assertEqual(len(self.sequences), len(augmented))
97103
self.assertEqual(create_n, len(augmented[0]))
98104
self.assertEqual(create_n, len(augmented[1]))
99105

100106
@parameterized.expand([(1,)])
101107
def test_char_swap(self, create_n):
102-
aug = CharSwap(create_n=create_n)
108+
aug = CharSwap(create_n=create_n, vocab="test_vocab")
103109
augmented = aug.augment(self.sequences)
104110
self.assertEqual(len(self.sequences), len(augmented))
105111
self.assertEqual(create_n, len(augmented[0]))

tests/dataaug/test_word_aug.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,18 @@ def set_random_seed(self, seed):
5050
def test_word_substitute(self, create_n):
5151
for t in self.types:
5252
if t == "mlm":
53-
aug = WordSubstitute("mlm", create_n=create_n, model_name="__internal_testing__/ernie")
53+
aug = WordSubstitute(
54+
"mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab"
55+
)
5456
augmented = aug.augment(self.sequences)
5557
self.assertEqual(len(self.sequences), len(augmented))
5658
continue
5759
elif t == "custom":
58-
aug = WordSubstitute("custom", create_n=create_n, custom_file_path=self.custom_file_path)
60+
aug = WordSubstitute(
61+
"custom", create_n=create_n, custom_file_path=self.custom_file_path, vocab="test_vocab"
62+
)
5963
else:
60-
aug = WordSubstitute(t, create_n=create_n)
64+
aug = WordSubstitute(t, create_n=create_n, vocab="test_vocab")
6165

6266
augmented = aug.augment(self.sequences)
6367
self.assertEqual(len(self.sequences), len(augmented))
@@ -68,14 +72,16 @@ def test_word_substitute(self, create_n):
6872
def test_word_insert(self, create_n):
6973
for t in self.types:
7074
if t == "mlm":
71-
aug = WordInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie")
75+
aug = WordInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab")
7276
augmented = aug.augment(self.sequences)
7377
self.assertEqual(len(self.sequences), len(augmented))
7478
continue
7579
elif t == "custom":
76-
aug = WordInsert("custom", create_n=create_n, custom_file_path=self.custom_file_path)
80+
aug = WordInsert(
81+
"custom", create_n=create_n, custom_file_path=self.custom_file_path, vocab="test_vocab"
82+
)
7783
else:
78-
aug = WordInsert(t, create_n=create_n)
84+
aug = WordInsert(t, create_n=create_n, vocab="test_vocab")
7985

8086
augmented = aug.augment(self.sequences)
8187
self.assertEqual(len(self.sequences), len(augmented))
@@ -84,15 +90,15 @@ def test_word_insert(self, create_n):
8490

8591
@parameterized.expand([(1,)])
8692
def test_word_delete(self, create_n):
87-
aug = WordDelete(create_n=create_n)
93+
aug = WordDelete(create_n=create_n, vocab="test_vocab")
8894
augmented = aug.augment(self.sequences)
8995
self.assertEqual(len(self.sequences), len(augmented))
9096
self.assertEqual(create_n, len(augmented[0]))
9197
self.assertEqual(create_n, len(augmented[1]))
9298

9399
@parameterized.expand([(1,)])
94100
def test_word_swap(self, create_n):
95-
aug = WordSwap(create_n=create_n)
101+
aug = WordSwap(create_n=create_n, vocab="test_vocab")
96102
augmented = aug.augment(self.sequences)
97103
self.assertEqual(len(self.sequences), len(augmented))
98104
self.assertEqual(create_n, len(augmented[0]))

0 commit comments

Comments (0)