@@ -44,56 +44,73 @@ def __init__(self, data_args, split, tokenizer):
44
44
split = split ,
45
45
tokenizer = tokenizer ,
46
46
)
47
- self .preprocessing_func = data_args .preprocessing_func
48
- self .remove_columns = data_args .remove_columns
49
47
50
48
def get_raw_dataset (self , * _ignore , ** __ignore ) -> Union [DatasetDict , Dataset ]:
51
49
"""Get the raw dataset and apply preprocessing func if provided"""
52
50
53
- dataset = self .data_args .dataset
54
- if isinstance (dataset , DatasetDict ) or isinstance (dataset , Dataset ):
55
- # user passed in an already instantiated dataset, just use it directly
56
- raw_dataset = dataset
57
- else :
58
- # dataset must be loaded from file or HF Hub
59
- raw_dataset = super ().get_raw_dataset ()
60
-
61
- if self .preprocessing_func is not None :
62
- if callable (self .preprocessing_func ):
63
- func = self .preprocessing_func
64
- elif ":" in self .preprocessing_func :
51
+ # load dataset
52
+ dataset = (
53
+ self .data_args .dataset
54
+ if isinstance (self .data_args .dataset , (DatasetDict , Dataset ))
55
+ else super ().get_raw_dataset () # load dataset from file or HF Hub
56
+ )
57
+
58
+ # preprocess dataset
59
+ dataset = self ._preprocess_dataset (dataset )
60
+ dataset = self ._remove_columns_from_dataset (dataset )
61
+
62
+ return dataset
63
+
64
+ def _preprocess_dataset (
65
+ self , dataset : Union [DatasetDict , Dataset ]
66
+ ) -> Union [DatasetDict , Dataset ]:
67
+ preprocessing_func = self .data_args .preprocessing_func
68
+
69
+ if preprocessing_func is not None :
70
+ if callable (preprocessing_func ):
71
+ pass
72
+
73
+ elif ":" in preprocessing_func :
65
74
# load func_name from "/path/to/file.py:func_name"
66
- func = import_from_path (self . preprocessing_func )
75
+ preprocessing_func = import_from_path (preprocessing_func )
67
76
else :
68
77
# load from the registry
69
- func = PreprocessingFunctionRegistry .get_value_from_registry (
70
- name = self .preprocessing_func
78
+ preprocessing_func = (
79
+ PreprocessingFunctionRegistry .get_value_from_registry (
80
+ name = preprocessing_func
81
+ )
71
82
)
72
83
73
- raw_dataset = self .map (
74
- raw_dataset ,
75
- function = func ,
84
+ dataset = self .map (
85
+ dataset ,
86
+ function = preprocessing_func ,
76
87
batched = False ,
77
88
num_proc = self .data_args .preprocessing_num_workers ,
78
89
desc = "Applying custom func to the custom dataset" ,
79
90
)
80
91
81
- self .remove_columns = (
82
- self .remove_columns or self .get_remove_columns_from_dataset (raw_dataset )
83
- )
92
+ return dataset
93
+
94
+ def _remove_columns_from_dataset (
95
+ self , dataset : Union [DatasetDict , Dataset ]
96
+ ) -> Union [DatasetDict , Dataset ]:
97
+ remove_columns = self .data_args .remove_columns
98
+
99
+ if not remove_columns :
100
+ remove_columns = self ._get_remove_columns_from_dataset (dataset )
84
101
85
- if self . remove_columns is not None :
86
- raw_dataset = self .map (
87
- raw_dataset ,
102
+ if remove_columns is not None :
103
+ dataset = self .map (
104
+ dataset ,
88
105
batched = True ,
89
- remove_columns = self . remove_columns ,
106
+ remove_columns = remove_columns ,
90
107
num_proc = self .data_args .preprocessing_num_workers ,
91
108
desc = "Removing unneeded columns" ,
92
109
)
93
110
94
- return raw_dataset
111
+ return dataset
95
112
96
- def get_remove_columns_from_dataset (
113
+ def _get_remove_columns_from_dataset (
97
114
self , raw_dataset : Union [DatasetDict , Dataset ]
98
115
) -> List [str ]:
99
116
"""Remove redandant columns from the dataset for processing"""
0 commit comments