@@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin):
     3. Tokenize dataset using model tokenizer/processor
     4. Apply post processing such as grouping text and/or adding labels for finetuning
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
     :param processor: processor or tokenizer to use on dataset
     """
@@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin):
 
     def __init__(
         self,
-        data_args: DatasetArguments,
+        dataset_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
-        self.data_args = data_args
+        self.dataset_args = dataset_args
         self.split = split
         self.processor = processor
@@ -58,23 +58,23 @@ def __init__(
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         # configure sequence length
-        max_seq_length = data_args.max_seq_length
-        if data_args.max_seq_length > self.tokenizer.model_max_length:
+        max_seq_length = dataset_args.max_seq_length
+        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
             logger.warning(
                 f"The max_seq_length passed ({max_seq_length}) is larger than "
                 f"maximum length for model ({self.tokenizer.model_max_length}). "
                 f"Using max_seq_length={self.tokenizer.model_max_length}."
             )
         self.max_seq_length = min(
-            data_args.max_seq_length, self.tokenizer.model_max_length
+            dataset_args.max_seq_length, self.tokenizer.model_max_length
         )
 
         # configure padding
         self.padding = (
             False
-            if self.data_args.concatenate_data
+            if self.dataset_args.concatenate_data
             else "max_length"
-            if self.data_args.pad_to_max_length
+            if self.dataset_args.pad_to_max_length
             else False
         )
 
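The padding expression above is a nested conditional; unrolled, it reads as the sketch below (same logic written out for clarity, not part of the change):

    if dataset_args.concatenate_data:
        padding = False          # samples are concatenated, so no padding is needed
    elif dataset_args.pad_to_max_length:
        padding = "max_length"   # pad each sample to max_seq_length
    else:
        padding = False          # leave samples unpadded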
@@ -83,7 +83,7 @@ def __init__(
             self.padding = False
 
     def __call__(self, add_labels: bool = True) -> DatasetType:
-        dataset = self.data_args.dataset
+        dataset = self.dataset_args.dataset
 
         if isinstance(dataset, str):
             # load dataset: load from huggingface or disk
@@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
                 dataset,
                 self.preprocess,
                 batched=False,
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Preprocessing",
             )
             logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")
@@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
             # regardless of `batched` argument
             remove_columns=get_columns(dataset),  # assumes that input names
             # and output names are disjoint
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
+            num_proc=self.dataset_args.preprocessing_num_workers,
+            load_from_cache_file=not self.dataset_args.overwrite_cache,
             desc="Tokenizing",
         )
         logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")
 
-        if self.data_args.concatenate_data:
+        if self.dataset_args.concatenate_data:
             # postprocess: group text
             dataset = self.map(
                 dataset,
                 self.group_text,
                 batched=True,
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Concatenating data",
             )
             logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")
@@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
                 dataset,
                 self.add_labels,
                 batched=False,  # not compatible with batching, need row lengths
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Adding labels",
             )
             logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
@@ -165,27 +165,31 @@ def load_dataset(self):
         :param cache_dir: disk location to search for cached dataset
         :return: the requested dataset
         """
-        if self.data_args.dataset_path is not None:
-            if self.data_args.dvc_data_repository is not None:
-                self.data_args.raw_kwargs["storage_options"] = {
-                    "url": self.data_args.dvc_data_repository
+        if self.dataset_args.dataset_path is not None:
+            if self.dataset_args.dvc_data_repository is not None:
+                self.dataset_args.raw_kwargs["storage_options"] = {
+                    "url": self.dataset_args.dvc_data_repository
                 }
-                self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    self.dataset_args.dataset_path
+                )
             else:
-                self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
-                    self.data_args.dataset_path,
-                    self.data_args.dataset
-                    if hasattr(self.data_args, "dataset")
-                    else self.data_args.dataset_name,
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    get_custom_datasets_from_path(
+                        self.dataset_args.dataset_path,
+                        self.dataset_args.dataset
+                        if hasattr(self.dataset_args, "dataset")
+                        else self.dataset_args.dataset_name,
+                    )
                 )
 
-        logger.debug(f"Loading dataset {self.data_args.dataset}")
+        logger.debug(f"Loading dataset {self.dataset_args.dataset}")
         return get_raw_dataset(
-            self.data_args,
+            self.dataset_args,
             None,
             split=self.split,
-            streaming=self.data_args.streaming,
-            **self.data_args.raw_kwargs,
+            streaming=self.dataset_args.streaming,
+            **self.dataset_args.raw_kwargs,
         )
 
     @cached_property
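As a sketch of how the local-path branch above gets exercised: the field names are taken from this diff, while the concrete values and file layout are assumptions.

    dataset_args = DatasetArguments(
        dataset="json",                   # assumed loader/registry name
        dataset_path="data/train.json",   # hypothetical local file
        dvc_data_repository=None,         # when set, becomes storage_options={"url": ...}
    )
    # with dataset_path set, load_dataset() fills raw_kwargs["data_files"] before
    # delegating to get_raw_dataset(dataset_args, ..., **raw_kwargs)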
@@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
         """
-        preprocessing_func = self.data_args.preprocessing_func
+        preprocessing_func = self.dataset_args.preprocessing_func
 
         if callable(preprocessing_func):
             return preprocessing_func
@@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
         # rename columns to match processor/tokenizer kwargs
         column_names = get_columns(dataset)
-        if self.data_args.text_column in column_names and "text" not in column_names:
-            logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
-            dataset = dataset.rename_column(self.data_args.text_column, "text")
+        if self.dataset_args.text_column in column_names and "text" not in column_names:
+            logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
+            dataset = dataset.rename_column(self.dataset_args.text_column, "text")
 
         return dataset