@@ -57,14 +57,18 @@ def set_random_seed(self, seed):
5757 def test_char_substitute (self , create_n ):
5858 for t in self .types :
5959 if t == "mlm" :
60- aug = CharSubstitute ("mlm" , create_n = create_n , model_name = "__internal_testing__/ernie" )
60+ aug = CharSubstitute (
61+ "mlm" , create_n = create_n , model_name = "__internal_testing__/ernie" , vocab = "test_vocab"
62+ )
6163 augmented = aug .augment (self .sequences )
6264 self .assertEqual (len (self .sequences ), len (augmented ))
6365 continue
6466 elif t == "custom" :
65- aug = CharSubstitute ("custom" , create_n = create_n , custom_file_path = self .custom_file_path )
67+ aug = CharSubstitute (
68+ "custom" , create_n = create_n , custom_file_path = self .custom_file_path , vocab = "test_vocab"
69+ )
6670 else :
67- aug = CharSubstitute (t , create_n = create_n )
71+ aug = CharSubstitute (t , create_n = create_n , vocab = "test_vocab" )
6872
6973 augmented = aug .augment (self .sequences )
7074 self .assertEqual (len (self .sequences ), len (augmented ))
@@ -75,14 +79,16 @@ def test_char_substitute(self, create_n):
7579 def test_char_insert (self , create_n ):
7680 for t in self .types :
7781 if t == "mlm" :
78- aug = CharInsert ("mlm" , create_n = create_n , model_name = "__internal_testing__/ernie" )
82+ aug = CharInsert ("mlm" , create_n = create_n , model_name = "__internal_testing__/ernie" , vocab = "test_vocab" )
7983 augmented = aug .augment (self .sequences )
8084 self .assertEqual (len (self .sequences ), len (augmented ))
8185 continue
8286 elif t == "custom" :
83- aug = CharInsert ("custom" , create_n = create_n , custom_file_path = self .custom_file_path )
87+ aug = CharInsert (
88+ "custom" , create_n = create_n , custom_file_path = self .custom_file_path , vocab = "test_vocab"
89+ )
8490 else :
85- aug = CharInsert (t , create_n = create_n )
91+ aug = CharInsert (t , create_n = create_n , vocab = "test_vocab" )
8692
8793 augmented = aug .augment (self .sequences )
8894 self .assertEqual (len (self .sequences ), len (augmented ))
@@ -91,15 +97,15 @@ def test_char_insert(self, create_n):
9197
9298 @parameterized .expand ([(1 ,)])
9399 def test_char_delete (self , create_n ):
94- aug = CharDelete (create_n = create_n )
100+ aug = CharDelete (create_n = create_n , vocab = "test_vocab" )
95101 augmented = aug .augment (self .sequences )
96102 self .assertEqual (len (self .sequences ), len (augmented ))
97103 self .assertEqual (create_n , len (augmented [0 ]))
98104 self .assertEqual (create_n , len (augmented [1 ]))
99105
100106 @parameterized .expand ([(1 ,)])
101107 def test_char_swap (self , create_n ):
102- aug = CharSwap (create_n = create_n )
108+ aug = CharSwap (create_n = create_n , vocab = "test_vocab" )
103109 augmented = aug .augment (self .sequences )
104110 self .assertEqual (len (self .sequences ), len (augmented ))
105111 self .assertEqual (create_n , len (augmented [0 ]))
0 commit comments