Commit b06c4cf

[StageRemoval] Remove Predict pathway (#1146)

SUMMARY:
* Remove the predict pathway from `main`
* Remove the predict pathway from `StageRunner`
* Remove the logic that builds the predict dataset split
* Remove `max_predict_samples` from `DatasetArguments`
* Remove all docs and comments that mention `predict`
* Rename `predicted_ids` to `output_ids`

TEST PLAN:
* Pass existing tests
* Delete tests involving predict

Signed-off-by: George <[email protected]>

1 parent ffd3ef9 · commit b06c4cf

File tree

8 files changed (+18 −92 lines):

* examples/quantization_w8a8_fp8/whisper_example.py
* src/llmcompressor/args/dataset_arguments.py
* src/llmcompressor/args/training_arguments.py
* src/llmcompressor/transformers/finetune/data/data_helpers.py
* src/llmcompressor/transformers/finetune/runner.py
* src/llmcompressor/transformers/finetune/session_mixin.py
* src/llmcompressor/transformers/finetune/text_generation.py
* tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py


examples/quantization_w8a8_fp8/whisper_example.py (2 additions, 2 deletions)

@@ -35,8 +35,8 @@
     sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
 ).input_features
 input_features = input_features.to(model.device)
-predicted_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
-print(processor.batch_decode(predicted_ids, skip_special_tokens=False)[0])
+output_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
+print(processor.batch_decode(output_ids, skip_special_tokens=False)[0])
 # Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel
 print("==========================================")
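For reference, a self-contained sketch of the renamed flow; the model id and dataset below are assumptions for illustration and may differ from what the example file actually loads:

    from datasets import load_dataset
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    model_id = "openai/whisper-tiny"  # placeholder; not necessarily the example's model
    model = WhisperForConditionalGeneration.from_pretrained(model_id)
    processor = WhisperProcessor.from_pretrained(model_id)

    sample = load_dataset(
        "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
    )[0]["audio"]
    input_features = processor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features.to(model.device)

    # Same generate/decode calls as before; only the variable name changed.
    output_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
    print(processor.batch_decode(output_ids, skip_special_tokens=False)[0])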

src/llmcompressor/args/dataset_arguments.py (0 additions, 9 deletions)

@@ -150,15 +150,6 @@ class DatasetArguments(CustomDatasetArguments):
             "of training examples to this value if set."
         },
     )
-    max_predict_samples: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": (
-                "For debugging purposes or quicker training, truncate the number of "
-                "prediction examples to this value if set."
-            ),
-        },
-    )
     min_tokens_per_module: Optional[float] = field(
         default=None,
         metadata={
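With the field removed, sample truncation for quick debugging runs is driven by `max_train_samples` alone. A minimal sketch, assuming the remaining fields keep their defaults (constructing without a dataset is for illustration only):

    from llmcompressor.args.dataset_arguments import DatasetArguments

    data_args = DatasetArguments(max_train_samples=128)

    # The removed field is now rejected, since dataclasses refuse unknown kwargs:
    try:
        DatasetArguments(max_predict_samples=64)
    except TypeError as err:
        print(f"removed field rejected: {err}")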

src/llmcompressor/args/training_arguments.py (2 additions, 2 deletions)

@@ -26,8 +26,8 @@ class TrainingArguments(HFTrainingArgs):
     output_dir: str = field(
         default="./output",
         metadata={
-            "help": "The output directory where the model predictions and "
-            "checkpoints will be written."
+            "help": "The output directory where the model safetensors, "
+            "recipe, config, and optionally checkpoints will be written."
         },
     )

src/llmcompressor/transformers/finetune/data/data_helpers.py (3 additions, 10 deletions)

@@ -97,17 +97,15 @@ def get_raw_dataset(
 def make_dataset_splits(
     tokenized_datasets: Dict[str, Any],
     do_train: bool = False,
-    do_predict: bool = False,
     do_oneshot: bool = False,
 ) -> Dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
-    (train, predict)
+    train

     :param tokenized_datasets: dictionary of processed datasets
-    :param do_train: Whether to store the train dataset
-    :param do_predict: Whether to store the test dataset
     :param do_oneshot: Whether to store the calibration dataset
+
     :return: Datasets to be used by the requested tasks
     """

@@ -117,16 +115,12 @@ def make_dataset_splits(
     if isinstance(tokenized_datasets, Dataset):
         tokenized_datasets = {"train": tokenized_datasets}

-    train_split = predict_split = calib_split = None
+    train_split = calib_split = None

     if do_train:
         if "train" not in tokenized_datasets:
             raise ValueError("--do_train requires a train dataset")
         train_split = tokenized_datasets["train"]
-    if do_predict:
-        if "test" not in tokenized_datasets:
-            raise ValueError("--do_predict requires a test dataset")
-        predict_split = tokenized_datasets["test"]
     if do_oneshot:
         calib_split = tokenized_datasets.get("calibration")
         if calib_split is None:

@@ -136,7 +130,6 @@ def make_dataset_splits(
     split_datasets = {
         "train": train_split,
-        "test": predict_split,
         "calibration": calib_split,
     }
     return split_datasets
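The new signature drops both the `do_predict` flag and the `"test"` key in the returned mapping. A minimal check of the updated behavior (the toy dataset below is illustrative):

    from datasets import Dataset

    from llmcompressor.transformers.finetune.data.data_helpers import make_dataset_splits

    tokenized = Dataset.from_dict({"input_ids": [[101, 102], [103, 104]]})
    splits = make_dataset_splits({"train": tokenized}, do_train=True)

    assert splits["train"] is not None
    assert "test" not in splits  # the "test" entry disappeared along with do_predict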

src/llmcompressor/transformers/finetune/runner.py (1 addition, 14 deletions)

@@ -38,7 +38,7 @@ class StageRunner:
     LifeCycle
     - populate_datasets()
     - set_trainer()
-    - train() / predict()
+    - train()

     :param model_args: Arguments pertaining to model/config/processor
     :param data_args: Arguments pertaining to what data to use for different flows

@@ -121,7 +121,6 @@ def _get_split_name(inp_str):
         self.datasets = make_dataset_splits(
             tokenized_datasets,
             do_train=self._training_args.do_train,
-            do_predict=self._training_args.do_predict,
             do_oneshot=self._training_args.do_oneshot,
         )

@@ -155,18 +154,6 @@ def train(self, checkpoint: str, stage: Optional[str] = None):
         # this includes saving the state, optimizer and scheduler
         self.trainer.save_model(output_dir=self._output_dir)

-    def predict(self):
-        """
-        Run trainer's prediction loop on predict_dataset, logging the desired metrics
-        """
-        logger.info("*** Predict ***")
-        results = self.trainer.predict(self.dataset["test"])
-        metrics = results.metrics
-
-        metrics["predict_samples"] = len(self.dataset["test"])
-        self.trainer.log_metrics("predict", metrics)
-        self.trainer.save_metrics("predict", metrics)
-
     def run_sequential_stages(self, checkpoint: Optional[str] = None):
         """
         Run the recipe stage by stage, allowing for alternating between one-shot and
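Per the updated LifeCycle docstring, the runner now moves straight from dataset population to training. A hedged sketch of that flow; the constructor and populate_datasets signatures are assumptions inferred from this diff, not verified against the class:

    from llmcompressor.transformers.finetune.runner import StageRunner

    # Hypothetical wiring; the argument objects are parsed elsewhere
    # (see text_generation.py), and `processor`/`trainer` are placeholders.
    stage_runner = StageRunner(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
    )
    stage_runner.populate_datasets(processor=processor)  # builds only train/calibration splits now
    stage_runner.set_trainer(trainer)
    if training_args.do_train:
        stage_runner.train(checkpoint)  # checkpoint may be None; there is no runner.predict()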

src/llmcompressor/transformers/finetune/session_mixin.py (0 additions, 41 deletions)

@@ -344,31 +344,6 @@ def compute_loss(

         return loss

-    def prediction_step(
-        self,
-        model: Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        """
-        Wraps the prediction step from the original trainer to remove any input entry
-        that should not be passed to the model.
-        This situation may arise when distillation is used and the teacher model
-        contains more inputs than the student model.
-        """
-        self._check_super_defined("prediction_step")
-
-        inputs = {k: inputs[k] for k in inputs if k in self._model_signature_columns}
-
-        model_outputs = super().prediction_step(
-            model=model,
-            inputs=inputs,
-            prediction_loss_only=prediction_loss_only,
-            ignore_keys=ignore_keys,
-        )
-        return model_outputs
-
     def train(self, *args, stage: Optional[str] = None, **kwargs):
         """
         Run a sparsification training cycle. Runs initialization for the sparse session

@@ -408,22 +383,6 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):

         return output

-    def predict(self, *args, **kwargs):
-        """
-        Run a sparsification prediction cycle.
-        Runs initialize_structure for the sparse session before calling
-        super().predict() and finalization of the session after.
-
-        :param args: positional args to pass to super().predict()
-        :param kwargs: keyword args to pass to super().predict()
-        :return: the output from super.predict()
-        """
-        self.initialize_structure()
-        output = super().predict(*args, **kwargs)
-        self.finalize_session()
-
-        return output
-
     def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False):
         """
         Override of the save_model function and expects it to exist in the parent.
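With the prediction_step override gone, evaluation calls fall through to the stock Hugging Face Trainer implementation, so teacher-only input columns are no longer filtered out for the student. If a downstream project still needs that filtering, it can be restored in a user-side subclass. A minimal sketch, not part of llm-compressor's API (it recomputes the model signature instead of relying on the deleted `_model_signature_columns` attribute):

    import inspect

    from transformers import Trainer


    class FilteringEvalTrainer(Trainer):
        """User-side stand-in for the deleted prediction_step override."""

        def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
            # Keep only the keys the model's forward() actually accepts, mirroring
            # what the removed override did with _model_signature_columns.
            allowed = set(inspect.signature(model.forward).parameters)
            inputs = {k: v for k, v in inputs.items() if k in allowed}
            return super().prediction_step(
                model,
                inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
            )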

src/llmcompressor/transformers/finetune/text_generation.py (4 additions, 6 deletions)

@@ -92,12 +92,13 @@ def eval(**kwargs):
 )
 def oneshot(**kwargs) -> None:
     from llmcompressor import oneshot
+
     oneshot(**kwargs)


 def apply(**kwargs):
     """
-    CLI entrypoint for any of training, predict or oneshot
+    CLI entrypoint for any of training, oneshot
     """
     report_to = kwargs.get("report_to", None)
     model_args, data_args, recipe_args, training_args = parse_args(**kwargs)

@@ -322,7 +323,8 @@ def main(
     - Trainer()
     - SessionMixIn()
     - HFTransformersTrainer()
-    - StageRunner.train() and/or predict() and/or oneshot()
+    - StageRunner.train() and/or oneshot()
+

     :param model_args: Arguments pertaining to which model/config/tokenizer we are
         going to fine-tune from

@@ -437,10 +439,6 @@ def main(
             checkpoint = last_checkpoint
         stage_runner.train(checkpoint)

-    # Prediction
-    if training_args.do_predict:
-        stage_runner.predict()
-
     # save if model was provided as a string or custom output_dir was set

     if isinstance(model_args.model, str) or (
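After this change, apply() dispatches only training and oneshot, passing its kwargs through parse_args as shown above. An illustrative invocation; the model and dataset values are placeholders, not taken from the repository:

    from llmcompressor.transformers.finetune.text_generation import apply

    apply(
        model="example-org/example-model",  # placeholder model id
        dataset="wikitext",                 # placeholder dataset name
        do_train=True,
        output_dir="./output",
    )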

tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py (6 additions, 8 deletions)

@@ -14,14 +14,11 @@ def test_combined_datasets():
     )
     raw_wikitext2 = get_raw_dataset(data_args)
     datasets = {"all": raw_wikitext2}
-
-    split_datasets = make_dataset_splits(datasets, do_train=True, do_predict=True)
+    split_datasets = make_dataset_splits(datasets, do_train=True)
     assert split_datasets.get("train") is not None
-    assert split_datasets.get("test") is not None

-    split_datasets = make_dataset_splits(datasets, do_train=True, do_predict=True)
+    split_datasets = make_dataset_splits(datasets, do_train=True)
     assert split_datasets.get("train") is not None
-    assert split_datasets.get("test") is not None


 @pytest.mark.unit

@@ -35,10 +32,11 @@ def test_separate_datasets():
     raw_wikitext2 = get_raw_dataset(data_args, split=split_str)
     datasets[split_name] = raw_wikitext2

-    split_datasets = make_dataset_splits(datasets, do_train=True, do_predict=False)
+    split_datasets = make_dataset_splits(datasets, do_train=True)
     assert split_datasets.get("train") is not None
-    assert split_datasets.get("test") is None

     with pytest.raises(ValueError):
         # fails due to no test split specified
-        split_datasets = make_dataset_splits(datasets, do_train=True, do_predict=True)
+
+        datasets.pop("train")
+        split_datasets = make_dataset_splits(datasets, do_train=True)
