Skip to content

Commit 26769ac

Browse files
committed
clean up CustomDataset
1 parent a47137d commit 26769ac

File tree

1 file changed

+46
-29
lines changed
  • src/llmcompressor/transformers/finetune/data

1 file changed

+46
-29
lines changed

src/llmcompressor/transformers/finetune/data/custom.py

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -44,56 +44,73 @@ def __init__(self, data_args, split, tokenizer):
4444
split=split,
4545
tokenizer=tokenizer,
4646
)
47-
self.preprocessing_func = data_args.preprocessing_func
48-
self.remove_columns = data_args.remove_columns
4947

5048
def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
5149
"""Get the raw dataset and apply preprocessing func if provided"""
5250

53-
dataset = self.data_args.dataset
54-
if isinstance(dataset, DatasetDict) or isinstance(dataset, Dataset):
55-
# user passed in an already instantiated dataset, just use it directly
56-
raw_dataset = dataset
57-
else:
58-
# dataset must be loaded from file or HF Hub
59-
raw_dataset = super().get_raw_dataset()
60-
61-
if self.preprocessing_func is not None:
62-
if callable(self.preprocessing_func):
63-
func = self.preprocessing_func
64-
elif ":" in self.preprocessing_func:
51+
# load dataset
52+
dataset = (
53+
self.data_args.dataset
54+
if isinstance(self.data_args.dataset, (DatasetDict, Dataset))
55+
else super().get_raw_dataset() # load dataset from file or HF Hub
56+
)
57+
58+
# preprocess dataset
59+
dataset = self._preprocess_dataset(dataset)
60+
dataset = self._remove_columns_from_dataset(dataset)
61+
62+
return dataset
63+
64+
def _preprocess_dataset(
65+
self, dataset: Union[DatasetDict, Dataset]
66+
) -> Union[DatasetDict, Dataset]:
67+
preprocessing_func = self.data_args.preprocessing_func
68+
69+
if preprocessing_func is not None:
70+
if callable(preprocessing_func):
71+
pass
72+
73+
elif ":" in preprocessing_func:
6574
# load func_name from "/path/to/file.py:func_name"
66-
func = import_from_path(self.preprocessing_func)
75+
preprocessing_func = import_from_path(preprocessing_func)
6776
else:
6877
# load from the registry
69-
func = PreprocessingFunctionRegistry.get_value_from_registry(
70-
name=self.preprocessing_func
78+
preprocessing_func = (
79+
PreprocessingFunctionRegistry.get_value_from_registry(
80+
name=preprocessing_func
81+
)
7182
)
7283

73-
raw_dataset = self.map(
74-
raw_dataset,
75-
function=func,
84+
dataset = self.map(
85+
dataset,
86+
function=preprocessing_func,
7687
batched=False,
7788
num_proc=self.data_args.preprocessing_num_workers,
7889
desc="Applying custom func to the custom dataset",
7990
)
8091

81-
self.remove_columns = (
82-
self.remove_columns or self.get_remove_columns_from_dataset(raw_dataset)
83-
)
92+
return dataset
93+
94+
def _remove_columns_from_dataset(
95+
self, dataset: Union[DatasetDict, Dataset]
96+
) -> Union[DatasetDict, Dataset]:
97+
remove_columns = self.data_args.remove_columns
98+
99+
if not remove_columns:
100+
remove_columns = self._get_remove_columns_from_dataset(dataset)
84101

85-
if self.remove_columns is not None:
86-
raw_dataset = self.map(
87-
raw_dataset,
102+
if remove_columns is not None:
103+
dataset = self.map(
104+
dataset,
88105
batched=True,
89-
remove_columns=self.remove_columns,
106+
remove_columns=remove_columns,
90107
num_proc=self.data_args.preprocessing_num_workers,
91108
desc="Removing unneeded columns",
92109
)
93110

94-
return raw_dataset
111+
return dataset
95112

96-
def get_remove_columns_from_dataset(
113+
def _get_remove_columns_from_dataset(
97114
self, raw_dataset: Union[DatasetDict, Dataset]
98115
) -> List[str]:
99116
"""Remove redandant columns from the dataset for processing"""

0 commit comments

Comments
 (0)