[StageRunner] Stage Runner entrypoint and pipeline #1202

Merged
merged 27 commits into main from stage-run on Apr 1, 2025

Conversation

@horheynm (Contributor) commented Feb 26, 2025

SUMMARY:

  • Remove from_args and reorganize logic in Oneshot. from_args was previously used for compatibility with the StageRunner class, which is removed.

  • Remove StageRunType and its logic from StageRunner and related files.

  • Remove the StageRunner class and its file.

  • Remove stage runner logic from transformers/text_generation.py.

  • Remove tests/llmcompressor/entrypoints/test_oneshot.py, which tested Oneshot.from_args, now removed.

  • Remove tests/llmcompressor/recipe/test_stage.py, which tested stage selection and run_type.

  • Add logic in train to return a PreTrainedModel, and change output_dir when a stage is passed to oneshot or train. If a stage is passed, the output directory changes from ./out to ./out/{stage} (see the sketch after this list).

  • Add stage to RecipeArguments.

  • Modify the saving logic in the trainer: use self.trainer.save instead of saving in post_process. post_process is still called, but only resets the session.

  • Modify post_process to save and reset, or only reset if no model_args or output_dir is passed in (model_args is needed for the model, output_dir for the save directory).
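
Below is a minimal sketch of the stage-aware output directory behavior described above. The helper name resolve_output_dir is hypothetical and for illustration only, not the PR's actual implementation:

import os
from typing import Optional


def resolve_output_dir(output_dir: Optional[str], stage: Optional[str]) -> Optional[str]:
    # Hypothetical sketch: when a stage is passed to oneshot or train, artifacts
    # are saved under output_dir/{stage}; otherwise output_dir is used as-is.
    if output_dir is None or stage is None:
        return output_dir
    return os.path.join(output_dir, stage)


print(resolve_output_dir("./out", "sparsity_stage"))  # ./out/sparsity_stage
print(resolve_output_dir("./out", None))              # ./out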

TEST PLAN:
Pass tests
Check that examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py runs and generates the same output as main.

SCRIPT:

import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot, train

recipe = r"""
sparsity_stage:
    sparsity_modifiers:
        SparseGPTModifier:
            sparsity: 0.5
            mask_structure: "2:4"
            targets: ["Linear"]
            ignore: ["re:.*lm_head"]
finetuning_stage:
    finetuning_modifiers:
        ConstantPruningModifier:
            targets: [
                're:.*q_proj.weight',
                're:.*k_proj.weight', 
                're:.*v_proj.weight',
                're:.*o_proj.weight',
                're:.*gate_proj.weight',
                're:.*up_proj.weight',
                're:.*down_proj.weight',
            ]
            start: 0
quantization_stage:
    quantization_modifiers:
        GPTQModifier:
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    targets: ["Linear"]
            
"""


# load the model in as bfloat16 to save on memory and compute
model_stub = "neuralmagic/Llama-2-7b-ultrachat200k"
model = AutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

# uses LLM Compressor's built-in preprocessing for ultra chat
dataset = "ultrachat-200k"

# save location of quantized model
output_dir = "output_llama7b_2of4_w4a16_channel"

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512

# set training parameters for finetuning
num_train_epochs = 0.01
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True  # saves memory during training
learning_rate = 0.0001
bf16 = False  # using full precision for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1
preprocessing_num_workers = 64 * 6

# this will run the recipe stage by stage:
# oneshot sparsification -> finetuning -> oneshot quantization

oneshot_kwargs = dict(
    dataset=dataset,
    recipe=recipe,
    num_calibration_samples=num_calibration_samples,
    preprocessing_num_workers=preprocessing_num_workers,
    splits=splits,
    output_dir=output_dir
)

training_kwargs = dict(
    bf16=bf16,
    max_seq_length=max_seq_length,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
)

oneshot_applied_model = oneshot(
    model=model,
    **oneshot_kwargs,
    stage="sparsity_stage",
)

finetune_applied_model = train(
    model=oneshot_applied_model,
    **oneshot_kwargs,
    **training_kwargs,
    stage="finetuning_stage",
)

oneshot_applied_model = oneshot(
    model=finetune_applied_model,
    **oneshot_kwargs,
    stage="quantization_stage",
)
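
Continuing the script above: given the stage-aware output_dir behavior described in the summary, each call should write its artifacts under a stage-named subdirectory of output_dir. The check below only prints the expected paths; the exact layout is an assumption based on the PR description, not verified output:

from pathlib import Path

# Expected per-stage output directories (assumed layout: output_dir/{stage}).
for stage in ("sparsity_stage", "finetuning_stage", "quantization_stage"):
    stage_dir = Path(output_dir) / stage
    print(stage_dir, "exists:", stage_dir.exists())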

@horheynm changed the title from "Stage run" to "[StageRunner] Stage Runner entrypoint and pipeline" on Feb 28, 2025
@horheynm marked this pull request as ready for review on March 14, 2025 11:40
@horheynm added the ready label on Mar 14, 2025
@horheynm removed the ready label on Mar 14, 2025
@horheynm added the ready label on Mar 15, 2025
@kylesayrs (Collaborator) previously approved these changes Mar 18, 2025 and left a comment:

Looks great, I much prefer this interface

@kylesayrs previously approved these changes Mar 26, 2025

@brian-dellabetta (Collaborator) left a comment:

A few clarifying questions, but I think this is ready to roll!

@dsikka enabled auto-merge (squash) March 30, 2025 21:24

@dsikka (Collaborator) left a comment:

This is currently failing the multi-stage example llama7b_sparse_w4a16.py under llm-compressor/examples/quantization_2of4_sparse_w4a16: it seems to get past the sparsity stage and then fails during finetuning.

@kylesayrs can you take a look?

@dsikka merged commit 1acf393 into main Apr 1, 2025
8 checks passed
@dsikka deleted the stage-run branch April 1, 2025 15:13