[StageRunner] Stage Runner entrypoint and pipeline #1202

Merged: 27 commits, Apr 1, 2025
@@ -1,13 +1,11 @@
sparsity_stage:
run_type: oneshot
sparsity_modifiers:
SparseGPTModifier:
sparsity: 0.5
mask_structure: "2:4"
targets: ["Linear"]
ignore: ["re:.*lm_head"]
finetuning_stage:
run_type: train
finetuning_modifiers:
ConstantPruningModifier:
targets: [
@@ -21,7 +19,6 @@ finetuning_stage:
]
start: 0
quantization_stage:
run_type: oneshot
quantization_modifiers:
GPTQModifier:
ignore: ["lm_head"]
@@ -1,13 +1,11 @@
sparsity_stage:
run_type: oneshot
sparsity_modifiers:
SparseGPTModifier:
sparsity: 0.5
mask_structure: "2:4"
targets: ["Linear"]
ignore: ["re:.*lm_head"]
finetuning_stage:
run_type: train
finetuning_modifiers:
ConstantPruningModifier:
targets: [
@@ -21,7 +19,6 @@ finetuning_stage:
]
start: 0
quantization_stage:
run_type: oneshot
quantization_modifiers:
GPTQModifier:
ignore: ["lm_head"]
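Both recipe diffs above drop the per-stage `run_type` key; stage selection moves to the call site, where the caller picks the matching entrypoint (`oneshot` or `train`) and passes the stage name, as the updated example script below does. A minimal sketch of that pairing (stage names are the ones in these recipes; `oneshot` and `train` are the entrypoints imported in the example):

```python
from llmcompressor import oneshot, train

# Each recipe stage is now driven by the entrypoint that its removed
# `run_type` key used to select. The actual call sites, with model and
# dataset arguments, appear in the example script diff below.
STAGE_ENTRYPOINTS = {
    "sparsity_stage": oneshot,      # was: run_type: oneshot
    "finetuning_stage": train,      # was: run_type: train
    "quantization_stage": oneshot,  # was: run_type: oneshot
}
```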
62 changes: 46 additions & 16 deletions examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
@@ -2,10 +2,7 @@
from loguru import logger
from transformers import AutoModelForCausalLM

from llmcompressor.transformers import apply

# define a recipe to handle sparsity, finetuning and quantization
recipe = "2of4_w4a16_recipe.yaml"
from llmcompressor import oneshot, train

# load the model in as bfloat16 to save on memory and compute
model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
@@ -16,6 +13,9 @@
# uses LLM Compressor's built-in preprocessing for ultra chat
dataset = "ultrachat-200k"

# Select the recipe for 2 of 4 sparsity and 4-bit activation quantization
recipe = "2of4_w4a16_recipe.yaml"

# save location of quantized model
output_dir = "output_llama7b_2of4_w4a16_channel"

@@ -33,31 +33,61 @@
bf16 = False # using full precision for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
preprocessing_num_workers = 8
preprocessing_num_workers = 64

# this will run the recipe stage by stage:
# oneshot sparsification -> finetuning -> oneshot quantization
apply(
model=model,

oneshot_kwargs = dict(
dataset=dataset,
recipe=recipe,
bf16=bf16,
output_dir=output_dir,
num_calibration_samples=num_calibration_samples,
preprocessing_num_workers=preprocessing_num_workers,
splits=splits,
output_dir=output_dir,
)

training_kwargs = dict(
bf16=bf16,
max_seq_length=max_seq_length,
num_calibration_samples=num_calibration_samples,
num_train_epochs=num_train_epochs,
logging_steps=logging_steps,
save_steps=save_steps,
gradient_checkpointing=gradient_checkpointing,
learning_rate=learning_rate,
lr_scheduler_type=lr_scheduler_type,
warmup_ratio=warmup_ratio,
preprocessing_num_workers=preprocessing_num_workers,
)
logger.info(
"llmcompressor does not currently support running compressed models in the marlin24 format." # noqa

# This will run the targeted stage of the recipe
# oneshot sparsification -> finetuning -> oneshot quantization

# Models are automatically saved in
# ./output_llama7b_2of4_w4a16_channel/ + (finetuning/sparsity/quantization)_stage

# Oneshot sparsification
oneshot_applied_model = oneshot(
model=model,
**oneshot_kwargs,
stage="sparsity_stage",
)

# Sparse finetune
finetune_applied_model = train(
model=oneshot_applied_model,
**oneshot_kwargs,
**training_kwargs,
stage="finetuning_stage",
)

# Oneshot quantization
model = oneshot(
model=finetune_applied_model,
**oneshot_kwargs,
stage="quantization_stage",
)

logger.info(
"The model produced from this example can be run on vLLM with dtype=torch.float16"
"llmcompressor does not currently support running ",
"compressed models in the marlin24 format. "
"The model produced from this example can be ",
"run on vLLM with dtype=torch.float16.",
)
2 changes: 1 addition & 1 deletion setup.py
@@ -55,7 +55,7 @@
"requests>=2.0.0",
"tqdm>=4.0.0",
"torch>=1.7.0",
"transformers>4.0,<5.0",
"transformers>4.0,<4.50",
"datasets",
"accelerate>=0.20.3,!=1.1.0",
"pynvml",
4 changes: 4 additions & 0 deletions src/llmcompressor/args/recipe_arguments.py
@@ -30,3 +30,7 @@ class RecipeArguments:
)
},
)
stage: Optional[str] = field(
default=None,
metadata={"help": ("The stage of the recipe to use for oneshot / train.",)},
)
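The new `stage` field is what the entrypoints forward as `recipe_args.stage`, so applying a single stage of a multi-stage recipe becomes a one-argument change at the call site. A minimal sketch, reusing the model stub, dataset, and recipe names from the example above and leaving every other argument at its default (calibration split and sample count defaults are assumed here, not taken from the diff):

```python
import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot

model = AutoModelForCausalLM.from_pretrained(
    "neuralmagic/Llama-2-7b-ultrachat200k", torch_dtype=torch.bfloat16
)

# Run only the sparsity stage of the staged recipe; the finetuning and
# quantization stages are applied later by their own entrypoint calls.
model = oneshot(
    model=model,
    dataset="ultrachat-200k",
    recipe="2of4_w4a16_recipe.yaml",
    output_dir="output_llama7b_2of4_w4a16_channel",
    stage="sparsity_stage",
)
```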
1 change: 0 additions & 1 deletion src/llmcompressor/core/session.py
@@ -115,7 +115,6 @@ def initialize(
:param kwargs: additional kwargs to pass to the lifecycle's initialize method
:return: the modified state of the session after initializing
"""

mod_data = self._lifecycle.initialize(
recipe=recipe,
recipe_stage=recipe_stage,
2 changes: 1 addition & 1 deletion src/llmcompressor/datasets/utils.py
@@ -14,7 +14,7 @@

def get_processed_dataset(
dataset_args: DatasetArguments,
processor: Processor,
processor: Optional[Processor] = None,
do_oneshot: bool = False,
do_train: bool = True,
) -> Optional[Dict[str, Dataset]]:
46 changes: 10 additions & 36 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -106,40 +106,14 @@ def __init__(
self.recipe_args = recipe_args
self.output_dir = output_dir

# initialize the model and processor
pre_process(model_args)

# Set instance attributes
self.model = self.model_args.model
self.processor = self.model_args.processor
self.recipe = self.recipe_args.recipe

@classmethod
def from_args(
cls,
model_args,
dataset_args,
recipe_args,
output_dir,
do_preprocess: bool = True,
):
"""
Used only for the stage runner to populate the args.
"""
instance = super().__new__(cls)
instance.model_args = model_args
instance.dataset_args = dataset_args
instance.recipe_args = recipe_args
instance.output_dir = output_dir

# only run for the first oneshot call
if do_preprocess:
pre_process(model_args)

# Set instance attributes
instance.model = instance.model_args.model
instance.recipe = instance.recipe_args.recipe
instance.processor = instance.model_args.processor

return instance

def __call__(self):
"""
Performs one-shot calibration.
@@ -150,20 +124,19 @@ def __call__(self):
postprocessing.

"""
# TODO: move back once stage runner is removed
# Preprocess the model and tokenizer/processor
pre_process(self.model_args)
self.model = self.model_args.model
self.recipe = self.recipe_args.recipe
self.processor = self.model_args.processor

calibration_dataloader = get_calibration_dataloader(
self.dataset_args, self.processor
)
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
recipe_stage=self.recipe_args.stage,
)
post_process(
model_args=self.model_args,
recipe_args=self.recipe_args,
output_dir=self.output_dir,
)
post_process(model_args=self.model_args, output_dir=self.output_dir)

def apply_recipe_modifiers(
self,
@@ -196,6 +169,7 @@ def apply_recipe_modifiers(
recipe_stage=recipe_stage,
)

session.reset()
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)

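Note the added `session.reset()` before `initialize`: each `oneshot` call now starts from a clean compression session rather than accumulating lifecycle state, so back-to-back entrypoint calls on the same model do not share modifier state. A sketch continuing from the snippet after the `recipe_arguments.py` diff above (the full example also runs the finetuning stage in between, which is omitted here):

```python
from llmcompressor import oneshot

shared_kwargs = dict(
    dataset="ultrachat-200k",
    recipe="2of4_w4a16_recipe.yaml",
    output_dir="output_llama7b_2of4_w4a16_channel",
)

# `model` is the model returned by the sparsity-stage call in the earlier
# sketch. The session is reset inside each call, so this stage does not
# inherit the previous stage's modifiers.
model = oneshot(model=model, stage="quantization_stage", **shared_kwargs)
```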
26 changes: 24 additions & 2 deletions src/llmcompressor/entrypoints/train.py
@@ -1,15 +1,18 @@
import math
import os

from loguru import logger
from transformers import PreTrainedModel

from llmcompressor.args import parse_args
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets.utils import get_processed_dataset
from llmcompressor.transformers.finetune.trainer import Trainer

from .utils import post_process, pre_process


def train(**kwargs):
def train(**kwargs) -> PreTrainedModel:
"""
Fine-tuning entrypoint that supports vanilla fine-tuning and
knowledge distillation for compressed model using `oneshot`.
@@ -67,6 +70,18 @@ def train(**kwargs):
)
training_dataset = processed_dataset.get("train")

# create output dir for stages
original_output_dir = output_dir = training_args.output_dir
if all([output_dir, recipe_args, getattr(recipe_args, "stage", None)]):
output_dir = os.path.join(original_output_dir, recipe_args.stage)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# update output dir in training args
logger.info(
f"Stage detected for training. Updating output dir to: {output_dir}"
)
training_args.output_dir = output_dir

trainer = Trainer(
model=model_args.model,
teacher=model_args.distill_teacher,
@@ -85,8 +100,12 @@ def train(**kwargs):
checkpoint = training_args.resume_from_checkpoint

logger.info("*** Train ***")

session = active_session()
session.reset()
train_result = trainer.train(
resume_from_checkpoint=checkpoint,
stage=recipe_args.stage,
)

# return output
@@ -99,4 +118,7 @@ def train(**kwargs):
# this includes saving the state, optimizer and scheduler
trainer.save_model(output_dir=training_args.output_dir)

post_process(model_args=model_args, output_dir=training_args.output_dir)
post_process(recipe_args=recipe_args)
training_args.output_dir = original_output_dir

return model_args.model
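With the stage-aware output directory added above (and the matching change in `post_process` below), a full run of the example is expected to leave one subdirectory per stage under the chosen `output_dir`, matching the `./output_llama7b_2of4_w4a16_channel/ + (finetuning/sparsity/quantization)_stage` comment in the example script. A small sketch of checking that layout (stage names come from the recipe; the parent directory name comes from the example):

```python
from pathlib import Path

output_dir = Path("output_llama7b_2of4_w4a16_channel")
for stage in ("sparsity_stage", "finetuning_stage", "quantization_stage"):
    # Oneshot stages are saved by post_process, the finetuning stage by
    # trainer.save_model; all of them nest under output_dir / stage.
    print(output_dir / stage, "exists:", (output_dir / stage).is_dir())
```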
40 changes: 27 additions & 13 deletions src/llmcompressor/entrypoints/utils.py
@@ -14,7 +14,8 @@
)
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.args import ModelArguments, TrainingArguments
from llmcompressor.args import ModelArguments, RecipeArguments, TrainingArguments
from llmcompressor.core import reset_session
from llmcompressor.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
@@ -66,11 +67,15 @@ def pre_process(model_args: "ModelArguments"):


def post_process(
model_args: "ModelArguments",
model_args: Optional["ModelArguments"] = None,
recipe_args: Optional["RecipeArguments"] = None,
output_dir: Optional[str] = None,
):
"""
Saves the model and tokenizer/processor to the output directory.
Saves the model and tokenizer/processor to the output directory if model_args,
output_dir is provided.

Save is skipped for stage runs for `train` - saves using the trainer.save_model()

If the `output_dir` is not the default directory, the method resets lifecycle
actions. The model is saved in a compressed format if specified in `model_args`.
@@ -79,20 +84,29 @@
Raises:
ValueError: If saving fails due to an invalid `output_dir` or other issues.
"""
if output_dir is not None:
if model_args is not None and output_dir is not None:
if recipe_args is not None and getattr(recipe_args, "stage", None) is not None:
output_dir = os.path.join(output_dir, recipe_args.stage)
os.makedirs(output_dir, exist_ok=True)
logger.info(f"[Save] Stage detected. Updating output_dir to {output_dir}")

model_args.model.save_pretrained(
output_dir,
save_compressed=model_args.save_compressed,
output_dir, save_compressed=model_args.save_compressed
)
if model_args.processor:

if model_args.processor is not None:
model_args.processor.save_pretrained(output_dir)
return

logger.warning(
"Optimized model is not saved. To save, please provide",
"`output_dir` as input arg.",
"Ex. `oneshot(..., output_dir=...)`",
)
else:
logger.warning(
"Optimized model is not saved. To save, please provide",
"`output_dir` as input arg.",
"Ex. `oneshot(..., output_dir=...)`",
)

# Reset the one-time-use session upon completion
if recipe_args is not None and recipe_args.clear_sparse_session:
reset_session()


def _warn_tied_embeddings(tie_word_embeddings: bool = False):
3 changes: 1 addition & 2 deletions src/llmcompressor/recipe/__init__.py
@@ -10,7 +10,7 @@
)
from .modifier import RecipeModifier
from .recipe import Recipe, RecipeArgsInput, RecipeInput, RecipeStageInput, RecipeTuple
from .stage import RecipeStage, StageRunType
from .stage import RecipeStage

__all__ = [
"DatasetMetaData",
@@ -25,7 +25,6 @@
"RecipeArgs",
"Recipe",
"RecipeTuple",
"StageRunType",
"RecipeInput",
"RecipeStageInput",
"RecipeArgsInput",