
Commit 391b202

horheynm and dsikka authored
[Cosmetic] Rename data_args to dataset_args (#1206)
Order of reviews: #1206 <-- Here, #1207, #1209, #1212, #1214

SUMMARY: Rename `data_args` to `dataset_args`.

TEST PLAN: Pass tests. Find remaining `data_args` references using `grep`.

---------

Signed-off-by: George Ohashi <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 07726ef commit 391b202

25 files changed (+256, -226 lines changed)
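The change is mechanical: every `data_args` keyword argument, attribute, and docstring reference becomes `dataset_args`, with no behavioral change. As a quick orientation, here is a minimal sketch of the post-rename calling convention, mirroring the example diff below (the import paths and the tokenizer checkpoint are assumptions for illustration, not part of this commit):

```python
# Sketch only: import paths and the model checkpoint are assumed, not verified
from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.finetune.data import TextGenerationDataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

dataset_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)
dataset_manager = TextGenerationDataset.load_from_registry(
    dataset_args.dataset,
    dataset_args=dataset_args,  # was: data_args=data_args
    split="train",
    processor=tokenizer,
)
train_dataset = dataset_manager(add_labels=True)  # __call__ tokenizes and post-processes
```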

examples/trl_mixin/ex_trl_distillation.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -19,12 +19,12 @@
 max_seq_length = 512
 
 # Load gsm8k using SparseML dataset tools
-data_args = DatasetArguments(
+dataset_args = DatasetArguments(
     dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
 )
 dataset_manager = TextGenerationDataset.load_from_registry(
-    data_args.dataset,
-    data_args=data_args,
+    dataset_args.dataset,
+    dataset_args=dataset_args,
     split="train",
     processor=tokenizer,
 )
@@ -69,7 +69,7 @@
     train_dataset=train_dataset,
     data_collator=data_collator,
     trl_sft_config_args=trl_sft_config_args,
-    data_args=data_args,
+    dataset_args=dataset_args,
     model_args=model_args,
 )
 trainer.train()
```

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -35,7 +35,7 @@ class Oneshot:
     `kwargs` are parsed into:
     - `model_args`: Arguments for loading and configuring a pretrained model
         (e.g., `AutoModelForCausalLM`).
-    - `data_args`: Arguments for dataset-related configurations, such as
+    - `dataset_args`: Arguments for dataset-related configurations, such as
         calibration dataloaders.
     - `recipe_args`: Arguments for defining and configuring recipes that specify
         optimization actions.
@@ -108,24 +108,23 @@ def __init__(
         """
         Initializes the `Oneshot` class with provided arguments.
 
-        Parses the input keyword arguments into `model_args`, `data_args`, and
+        Parses the input keyword arguments into `model_args`, `dataset_args`, and
         `recipe_args`. Performs preprocessing to initialize the model and
         tokenizer/processor.
 
         :param model_args: ModelArguments parameters, responsible for controlling
             model loading and saving logic
-        :param data_args: DatasetArguments parameters, responsible for controlling
+        :param dataset_args: DatasetArguments parameters, responsible for controlling
             dataset loading, preprocessing and dataloader loading
         :param recipe_args: RecipeArguments parameters, responsible for containing
             recipe-related parameters
         :param output_dir: Path to save the output model after carrying out oneshot
 
         """
-
         model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)
 
         self.model_args = model_args
-        self.data_args = dataset_args
+        self.dataset_args = dataset_args
         self.recipe_args = recipe_args
         self.output_dir = output_dir
@@ -136,14 +135,19 @@ def __init__(
 
     @classmethod
     def from_args(
-        cls, model_args, data_args, recipe_args, output_dir, do_preprocess: bool = True
+        cls,
+        model_args,
+        dataset_args,
+        recipe_args,
+        output_dir,
+        do_preprocess: bool = True,
     ):
         """
         Used only for the stage runner to populate the args.
         """
         instance = super().__new__(cls)
         instance.model_args = model_args
-        instance.data_args = data_args
+        instance.dataset_args = dataset_args
         instance.recipe_args = recipe_args
         instance.output_dir = output_dir
@@ -176,7 +180,7 @@ def __call__(self):
         self.processor = self.model_args.processor
 
         calibration_dataloader = get_calibration_dataloader(
-            self.data_args, self.processor
+            self.dataset_args, self.processor
         )
         self.apply_recipe_modifiers(
             calibration_dataloader=calibration_dataloader,
@@ -242,7 +246,7 @@ def _pre_process(self):
         - Applies patches to fix tied tensor issues and modifies `save_pretrained`
             behavior.
         - Initializes the processor if specified as a path or `None`.
-        - Sets the minimum tokens per module if `data_args` are provided.
+        - Sets the minimum tokens per module if `dataset_args` are provided.
 
         Raises:
             FileNotFoundError: If the model or processor path is invalid.
@@ -265,8 +269,8 @@ def _pre_process(self):
             self.processor = self.model_args.processor
 
         # Set minimum tokens per module if data arguments are provided
-        if self.data_args:
-            self.min_tokens_per_module = self.data_args.min_tokens_per_module
+        if self.dataset_args:
+            self.min_tokens_per_module = self.dataset_args.min_tokens_per_module
 
     def check_tied_embeddings(self):
         """
```

src/llmcompressor/transformers/finetune/data/base.py

Lines changed: 40 additions & 36 deletions
```diff
@@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin):
     3. Tokenize dataset using model tokenizer/processor
     4. Apply post processing such as grouping text and/or adding labels for finetuning
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
     :param processor: processor or tokenizer to use on dataset
     """
@@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin):
 
     def __init__(
         self,
-        data_args: DatasetArguments,
+        dataset_args: DatasetArguments,
         split: str,
         processor: Processor,
     ):
-        self.data_args = data_args
+        self.dataset_args = dataset_args
         self.split = split
         self.processor = processor
 
@@ -58,23 +58,23 @@ def __init__(
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         # configure sequence length
-        max_seq_length = data_args.max_seq_length
-        if data_args.max_seq_length > self.tokenizer.model_max_length:
+        max_seq_length = dataset_args.max_seq_length
+        if dataset_args.max_seq_length > self.tokenizer.model_max_length:
             logger.warning(
                 f"The max_seq_length passed ({max_seq_length}) is larger than "
                 f"maximum length for model ({self.tokenizer.model_max_length}). "
                 f"Using max_seq_length={self.tokenizer.model_max_length}."
             )
         self.max_seq_length = min(
-            data_args.max_seq_length, self.tokenizer.model_max_length
+            dataset_args.max_seq_length, self.tokenizer.model_max_length
         )
 
         # configure padding
         self.padding = (
             False
-            if self.data_args.concatenate_data
+            if self.dataset_args.concatenate_data
             else "max_length"
-            if self.data_args.pad_to_max_length
+            if self.dataset_args.pad_to_max_length
             else False
         )
 
@@ -83,7 +83,7 @@ def __init__(
             self.padding = False
 
     def __call__(self, add_labels: bool = True) -> DatasetType:
-        dataset = self.data_args.dataset
+        dataset = self.dataset_args.dataset
 
         if isinstance(dataset, str):
             # load dataset: load from huggingface or disk
@@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
             dataset,
             self.preprocess,
             batched=False,
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
+            num_proc=self.dataset_args.preprocessing_num_workers,
+            load_from_cache_file=not self.dataset_args.overwrite_cache,
             desc="Preprocessing",
         )
         logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")
@@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
             # regardless of `batched` argument
             remove_columns=get_columns(dataset),  # assumes that input names
             # and output names are disjoint
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
+            num_proc=self.dataset_args.preprocessing_num_workers,
+            load_from_cache_file=not self.dataset_args.overwrite_cache,
             desc="Tokenizing",
         )
         logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")
 
-        if self.data_args.concatenate_data:
+        if self.dataset_args.concatenate_data:
             # postprocess: group text
             dataset = self.map(
                 dataset,
                 self.group_text,
                 batched=True,
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Concatenating data",
             )
             logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")
@@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
                 dataset,
                 self.add_labels,
                 batched=False,  # not compatible with batching, need row lengths
-                num_proc=self.data_args.preprocessing_num_workers,
-                load_from_cache_file=not self.data_args.overwrite_cache,
+                num_proc=self.dataset_args.preprocessing_num_workers,
+                load_from_cache_file=not self.dataset_args.overwrite_cache,
                 desc="Adding labels",
             )
             logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
@@ -165,27 +165,31 @@ def load_dataset(self):
         :param cache_dir: disk location to search for cached dataset
         :return: the requested dataset
         """
-        if self.data_args.dataset_path is not None:
-            if self.data_args.dvc_data_repository is not None:
-                self.data_args.raw_kwargs["storage_options"] = {
-                    "url": self.data_args.dvc_data_repository
+        if self.dataset_args.dataset_path is not None:
+            if self.dataset_args.dvc_data_repository is not None:
+                self.dataset_args.raw_kwargs["storage_options"] = {
+                    "url": self.dataset_args.dvc_data_repository
                 }
-                self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    self.dataset_args.dataset_path
+                )
             else:
-                self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
-                    self.data_args.dataset_path,
-                    self.data_args.dataset
-                    if hasattr(self.data_args, "dataset")
-                    else self.data_args.dataset_name,
+                self.dataset_args.raw_kwargs["data_files"] = (
+                    get_custom_datasets_from_path(
+                        self.dataset_args.dataset_path,
+                        self.dataset_args.dataset
+                        if hasattr(self.dataset_args, "dataset")
+                        else self.dataset_args.dataset_name,
+                    )
                 )
 
-        logger.debug(f"Loading dataset {self.data_args.dataset}")
+        logger.debug(f"Loading dataset {self.dataset_args.dataset}")
         return get_raw_dataset(
-            self.data_args,
+            self.dataset_args,
             None,
             split=self.split,
-            streaming=self.data_args.streaming,
-            **self.data_args.raw_kwargs,
+            streaming=self.dataset_args.streaming,
+            **self.dataset_args.raw_kwargs,
         )
 
     @cached_property
@@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
         """
-        preprocessing_func = self.data_args.preprocessing_func
+        preprocessing_func = self.dataset_args.preprocessing_func
 
         if callable(preprocessing_func):
             return preprocessing_func
@@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
         # rename columns to match processor/tokenizer kwargs
         column_names = get_columns(dataset)
-        if self.data_args.text_column in column_names and "text" not in column_names:
-            logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
-            dataset = dataset.rename_column(self.data_args.text_column, "text")
+        if self.dataset_args.text_column in column_names and "text" not in column_names:
+            logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
+            dataset = dataset.rename_column(self.dataset_args.text_column, "text")
 
         return dataset
```
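`TextGenerationDataset` reads its entire configuration off the renamed `dataset_args` object: `dataset`, `max_seq_length`, `concatenate_data`, `pad_to_max_length`, `preprocessing_num_workers`, `overwrite_cache`, `text_column`, plus the `dataset_path`/`raw_kwargs` fields used for local files. A hedged sketch of constructing such a configuration (the field names come from the attribute accesses above; the values, the import path, and keyword construction are illustrative assumptions):

```python
# Illustrative only: values are arbitrary, import path and keyword construction assumed
from llmcompressor.args import DatasetArguments

dataset_args = DatasetArguments(
    dataset="wikitext",              # HF dataset id; use dataset_path for local csv/json
    max_seq_length=2048,             # clamped to tokenizer.model_max_length in base.py
    concatenate_data=False,          # True -> group text instead of padding
    pad_to_max_length=True,          # padding="max_length" when not concatenating
    preprocessing_num_workers=4,     # forwarded as num_proc to dataset.map(...)
    overwrite_cache=False,           # load_from_cache_file=not overwrite_cache
    text_column="text",              # renamed to "text" if it is not already
)
```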

src/llmcompressor/transformers/finetune/data/c4.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -13,14 +13,16 @@ class C4Dataset(TextGenerationDataset):
     """
     Child text generation class for the C4 dataset
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
     :param processor: processor or tokenizer to use on dataset
     """
 
-    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
-        data_args = deepcopy(data_args)
-        data_args.dataset = "allenai/c4"
-        data_args.text_column = "text"
+    def __init__(
+        self, dataset_args: "DatasetArguments", split: str, processor: Processor
+    ):
+        dataset_args = deepcopy(dataset_args)
+        dataset_args.dataset = "allenai/c4"
+        dataset_args.text_column = "text"
 
-        super().__init__(data_args=data_args, split=split, processor=processor)
+        super().__init__(dataset_args=dataset_args, split=split, processor=processor)
```

src/llmcompressor/transformers/finetune/data/cnn_dailymail.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -13,19 +13,21 @@ class CNNDailyMailDataset(TextGenerationDataset):
     """
     Text generation class for the CNN/DailyMail dataset
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
     :param processor: processor or tokenizer to use on dataset
     """
 
     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"
 
-    def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
-        data_args = deepcopy(data_args)
-        data_args.dataset = "cnn_dailymail"
-        data_args.dataset_config_name = "3.0.0"
+    def __init__(
+        self, dataset_args: "DatasetArguments", split: str, processor: Processor
+    ):
+        dataset_args = deepcopy(dataset_args)
+        dataset_args.dataset = "cnn_dailymail"
+        dataset_args.dataset_config_name = "3.0.0"
 
-        super().__init__(data_args=data_args, split=split, processor=processor)
+        super().__init__(dataset_args=dataset_args, split=split, processor=processor)
 
     def dataset_template(self, sample):
         return {
```

src/llmcompressor/transformers/finetune/data/custom.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -7,7 +7,7 @@ class CustomDataset(TextGenerationDataset):
     Child text generation class for custom local dataset supporting load
     for csv and json
 
-    :param data_args: configuration settings for dataset loading
+    :param dataset_args: configuration settings for dataset loading
     :param split: split from dataset to load, for instance `test` or `train[:5%]`
         Can also be set to None to load all the splits
     :param processor: processor or tokenizer to use on dataset
```
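Dataset subclasses follow the same pattern as `C4Dataset` and `CNNDailyMailDataset` above: copy the incoming `dataset_args`, pin dataset-specific fields, and forward it with the renamed keyword. A hedged sketch of a new subclass (the `register` decorator usage and the import path are assumptions based on the `RegistryMixin` base, not part of this diff):

```python
# Hypothetical subclass sketch; registry decorator and import path are assumed
from copy import deepcopy

from llmcompressor.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="my_dataset")
class MyDataset(TextGenerationDataset):
    def __init__(self, dataset_args, split, processor):
        dataset_args = deepcopy(dataset_args)
        dataset_args.dataset = "my_org/my_dataset"  # pin the HF dataset id
        dataset_args.text_column = "content"        # column that will be renamed to "text"

        # keyword is now `dataset_args=` (previously `data_args=`)
        super().__init__(dataset_args=dataset_args, split=split, processor=processor)
```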
