Vision Datasets #943

Merged
merged 57 commits into from Dec 24, 2024

Conversation

@kylesayrs kylesayrs commented Nov 28, 2024

Background

Lifecycle of a dataset

Action | Data form | Purpose | Where
-- | -- | -- | --
Load dataset | String, dataset | Potentially load from HF/DVC | TextGenerationDataset
Preprocess (apply chat template) | Raw dataset | Format data for processor | TextGenerationDataset.preprocess
Rename columns | Preprocessed dataset + prompt key | Support datasets with text columns not named "text" | TextGenerationDataset.rename_columns
Filter tokenizer args | Preprocessed dataset + prompt key | Remove columns not used by tokenizer (automates data_args.remove_columns) | TextGenerationDataset.filter_tokenizer_args
Tokenize (process) | Dataset of processor kwargs + prompt key | Tokenize into model kwargs | TextGenerationDataset.tokenize
Postprocess | Dataset of model kwargs | Potentially concatenate data, add labels for finetuning | TextGenerationDataset.group_text and/or TextGenerationDataset.add_labels
Apply dataloader | Dataset of model kwargs | Batch model kwargs for calibration/training | format_calibration_data or Trainer.get_X_dataloader
Forward pass | Model kwargs | Train or calibrate model | PretrainedModel.forward
  • The user can pass a dataset in any one of the above intermediate forms, and TextGenerationDataset will skip all previous steps and only perform the necessary subsequent steps. Typically this means passing a dataset name, raw dataset, or tokenized dataset.
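
For example, a dataset that is already tokenized can be handed over directly and only the later steps run. The snippet below is a minimal sketch: the `oneshot` call is left commented out, and the exact keyword handling follows the test scripts further down rather than a guaranteed API.

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Model kwargs produced ahead of time: load/preprocess/tokenize are skipped and
# only postprocessing and the dataloader step still apply.
tokenized_ds = Dataset.from_list(
    [dict(tokenizer("Hello my name is", max_length=2048, truncation=True))]
)

# oneshot(model=model, dataset=tokenized_ds, recipe=recipe, ...)
```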

Purpose

  • Clean up and clarify the dataset lifecycle
  • Support vision processors, which means supporting arbitrary arguments as input
    • Previously we relied on the user specifying the "text column", which was the only input into the tokenizer
    • Vision processors take text as input, but also many other arguments such as pixel_values, aspect_ratio, etc.
  • No longer require users to specify remove_columns; instead, infer this information from the tokenizer call signature (see the sketch after this list)
  • Allow users to pass their own custom data_collator function, which is necessary for many vision datasets
  • Allow preprocessing_func to be used with any dataset
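
A minimal sketch of the call-signature inference; the helper name `filter_processor_args` is illustrative, not the actual implementation, which lives in `TextGenerationDataset.filter_tokenizer_args`:

```python
import inspect

from datasets import Dataset


def filter_processor_args(dataset: Dataset, processor) -> Dataset:
    # Parameter names accepted by the processor's __call__ (e.g. text, images, videos)
    accepted = set(inspect.signature(processor.__call__).parameters) - {"args", "kwargs"}
    # Drop every column the processor would not accept as a keyword argument
    extra_columns = [name for name in dataset.column_names if name not in accepted]
    return dataset.remove_columns(extra_columns)
```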

Changes

  • Refactor TextGenerationDataset class to clarify dataset lifecycle
  • Add Flickr30K dataset
  • Deprecate the remove_columns argument
  • Add data_collator argument

Prerequisites

Sample Debug Logs

```
2024-12-20 20:59:00.115 | DEBUG    | llmcompressor.transformers.finetune.data.base:load_dataset:179 - Loading dataset lmms-lab/flickr30k
2024-12-20 20:59:01.073 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:91 - Raw dataset: ['image', 'caption', 'sentids', 'img_id', 'filename']
2024-12-20 20:59:01.171 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:102 - Dataset after preprocessing: ['image', 'caption', 'sentids', 'img_id', 'filename', 'text', 'images']
2024-12-20 20:59:01.171 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:106 - Dataset after column renaming: ['image', 'caption', 'sentids', 'img_id', 'filename', 'text', 'images']
2024-12-20 20:59:01.171 | DEBUG    | llmcompressor.transformers.finetune.data.base:filter_tokenizer_args:232 - Found processor args `{'images', 'text', 'videos'}`. Removing all other columns
2024-12-20 20:59:01.172 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:111 - Tokenizer args after filtering: ['text', 'images']
Tokenizing: 100%|???????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????????| 3/3 [00:00<00:00, 61.32 examples/s]
2024-12-20 20:59:01.289 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:125 - Model kwargs after tokenizing: ['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw']
2024-12-20 20:59:01.289 | DEBUG    | llmcompressor.transformers.finetune.data.base:__call__:155 - Model kwargs after postprocessing: ['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw']
```

Testing

  • Able to run custom datasets with examples/quantization_w4a16/llama3_example.py
  • Use the following scripts to verify that the flickr vision dataset loads and can be passed to the model (although only the first layer will quantize due to missing changes from VLM Support via GPTQ Hooks and Data Pipelines #914)
  • Nightly
mllama.py

```python
import os

import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Load model.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 1
MAX_SEQUENCE_LENGTH = 2048


# TODO: define real collators in utils
def data_collator(batch):
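    # Calibration runs with batch size 1, so collate the single sample directly into model kwargs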
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
        "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]),
        "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]),
        "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]),
    }


# Recipe
recipe = [
    # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
        update_size=NUM_CALIBRATION_SAMPLES,
    ),
]

# Perform oneshot
save_name = model_id.split("/")[1] + "-W8A8"
save_path = os.path.join("./my_test/", save_name)
print("Starting quantization")
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    output_dir=save_path,
    data_collator=data_collator,
)

processor.save_pretrained(save_path)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
qwen.py

```python
import os

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Load model.
model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test[:3]"
NUM_CALIBRATION_SAMPLES = 1
MAX_SEQUENCE_LENGTH = 2048


# TODO: define real collators in utils
def data_collator(batch):
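    # Single-sample calibration batch; convert lists back into tensors with the expected shapes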
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(
            batch[0]["pixel_values"]
        ),  # torch.Size([14308, 1176])
        "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]),
    }


# Recipe
recipe = GPTQModifier(
    targets="Linear",
    config_groups={
        "config_group": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                strategy=QuantizationStrategy.GROUP,
                group_size=128,
                symmetric=True,
                dynamic=False,
                actorder="dynamic",
            ),
        ),
    },
    ignore=["re:.*lm_head"],
    update_size=NUM_CALIBRATION_SAMPLES,
    dampening_frac=0.5,
)

# Perform oneshot
save_name = model_id.split("/")[1] + "-W4A16"  # recipe above uses 4-bit group quantization
save_path = os.path.join("./my_test/", save_name)
print("Starting quantization")
oneshot(
    model=model,
    tokenizer=model_id,
    # dataset=ds,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    output_dir=save_path,
    data_collator=data_collator,
)

processor.save_pretrained(save_path)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")

@kylesayrs kylesayrs force-pushed the kylesayrs/cleanup-custom-dataset branch from 26769ac to 62dd240 Compare November 28, 2024 17:01
@kylesayrs kylesayrs changed the title clean up CustomDataset Clean up CustomDataset Nov 28, 2024
@kylesayrs kylesayrs self-assigned this Nov 28, 2024
@kylesayrs kylesayrs marked this pull request as draft November 28, 2024 19:07
@kylesayrs kylesayrs removed their assignment Nov 28, 2024
@kylesayrs kylesayrs changed the title Clean up CustomDataset Vision Datasets Nov 29, 2024
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
…tokenized datasets should not be given labels

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs kylesayrs force-pushed the kylesayrs/cleanup-custom-dataset branch from 92a5f16 to 72aecfc Compare December 2, 2024 22:14
@vllm-project vllm-project deleted a comment from github-actions bot Dec 3, 2024
@dsikka dsikka left a comment

I think overall this looks good. Mostly just a lot of questions.

I would test this against a staged workflow, e.g. https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_2of4_sparse_w4a16
My general feedback is to minimize how much we're abstracting away from users.

Signed-off-by: Kyle Sayers <[email protected]>
dsikka previously approved these changes Dec 20, 2024

@dsikka dsikka left a comment

LGTM. Why are there ???? in the logs?

@kylesayrs (Collaborator Author)

@dsikka I use GNU screen, which sometimes doesn't handle unicode characters properly. This is purely a visual bug on my side.

@dsikka dsikka requested a review from mgoin December 20, 2024 21:46
Signed-off-by: Kyle Sayers <[email protected]>
dsikka previously approved these changes Dec 20, 2024

@dsikka dsikka left a comment

Approved pending test failures

Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs (Collaborator Author)

Tests now pass

@dsikka dsikka left a comment

Thanks!

@mgoin mgoin left a comment

Thanks for the discussion!

@dsikka dsikka merged commit 3da4519 into main Dec 24, 2024
6 of 7 checks passed
@dsikka dsikka deleted the kylesayrs/cleanup-custom-dataset branch December 24, 2024 01:59
dsikka added a commit that referenced this pull request Jan 8, 2025
## Purpose ##
* Enable oneshot quantization of vision-language models

![VLM Banner](https://github.com/user-attachments/assets/0d748714-b524-44f4-b850-a721f35d5543)
[Llama_3 2-Vision Graphviz](https://github.com/user-attachments/assets/6b371ccc-f9f6-4bf2-b4cd-24ed75a3cad0)

## Related Issues ##
* Fixes #91
* Fixes #961
* Fixes #990

## Prerequisites ##
* neuralmagic/compressed-tensors#193
* #917
* #943
  * #955
    * #950
* #998
* #1014

## Changes ##
### VLM Support ###
* Add multimodal examples in `examples/multimodal_vision`
* Modify `custom_offload_device_map` to support models which are not
`XForCausalLM`
* Add custom data collators for VLM models in
`src/llmcompressor/transformers/utils/data_collator.py`

### GPTQModifier ###
* Implement hooks-based compression in `GPTQModifier`
* This replaces layer-compressor, which made many assumptions about
model architecture
* This also enables finer-grained sequential compression such as
[true_sequential](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig.true_sequential)
* Functions previously implemented in `gptq_wrapper.py` are now
implemented in `gptq_quantize.py`
* Implement `offload_hessians` parameter in `GPTQModifier`
* Implement data-pipelines-based calibration in `GPTQModifier`
* First an attempt will be made to trace the model and run the
`sequential` pipeline
* If that fails, assumptions will be made about the model architecture
and an attempt will be made to run the `layer_sequential` pipeline
* This ensures backwards compatibility with any previously supported
models
* If that fails, then the basic pipeline will be used, which is
guaranteed to run but may require using `offload_hessians` (see the
fallback sketch after this list)
* Change hessian instability from a `ValueError` to a `_LinAlgError` so
it can be ignored by the gptq pipeline fallback mechanism
* Add support for conv2d as indicated by
[AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ/blob/6689349625de973b9ee3016c28c11f32acf7f02c/auto_gptq/quantization/gptq.py#L45-L54)
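
The fallback ordering described above might look roughly like the sketch below; `run_with_fallback` and the commented pipeline names are illustrative placeholders, not the actual functions in this PR:

```python
from typing import Callable, Optional, Sequence


def run_with_fallback(pipelines: Sequence[Callable[[], None]]) -> None:
    """Try each calibration pipeline in order, moving to the next on failure."""
    last_error: Optional[Exception] = None
    for pipeline in pipelines:
        try:
            pipeline()
            return
        except Exception as err:  # e.g. tracing failures or hessian instability
            last_error = err
    if last_error is not None:
        raise last_error


# Hypothetical ordering mirroring the description above:
# run_with_fallback([run_sequential, run_layer_sequential, run_basic])
```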

### Data Pipelines ###
* Implement the basic skeletons of data pipelines, which are subject to
change when data pipelines are pulled out of modifiers
* Basic Pipeline
* Performs standard forward passes through the model with provided
dataloader
* Used as fallback, as well as in the future for basic calibration
passes
* Layer Sequential Pipeline
  * Refactor of `LayerCompressor` as a straight-forward data pipeline
  * Uses `IntermediatesCache` to handle activation offloading
* Sequential Pipeline
* Utilizes graph tracing implemented by `torch.fx` to trace the graph in
order to determine where sequential targets (layers) exist in the graph
and what their inputs and outputs are (a toy sketch follows this list)
  * Implements BFS algorithm to assign nodes to partitions
* An ideal implementation consolidates partition indices to assign each
node to the latest possible partition, delaying execution. The current
implementation addresses the most common case (node.op == get_attr)
* Each partition (`Subgraph`) is compiled as an executable python
function with the proper inputs and outputs
  * Uses `IntermediatesCache` to handle activation offloading
* Implement `IntermediatesCache` which automagically handles the
offloading and onloading of activations from batches
* This class is capable of offloading many non-standard activation types
such as `Tuple`s and dataclasses such as `BaseModelOutputWithPast`
  * For convenience, the class also handles masking padding
  * The class is tested in `tests/llmcompressor/pipelines/test_cache.py`
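
As a toy illustration of the graph-tracing idea (not the actual `Subgraph` or partitioning code in this PR), `torch.fx` can trace a model and the resulting nodes can be grouped into partitions at each sequential target:

```python
import torch
from torch.fx import symbolic_trace


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(3))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# Trace the model, then close a partition every time a sequential target
# (here: any call_module node) is reached -- a toy stand-in for the BFS pass.
graph_module = symbolic_trace(TinyModel())
partitions, current = [], []
for node in graph_module.graph.nodes:
    current.append(node)
    if node.op == "call_module":
        partitions.append(current)
        current = []
if current:
    partitions.append(current)
print([[node.name for node in partition] for partition in partitions])
```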

### Tracing ###
* In order to support sequential quantization of the large variety of
different multimodal model architectures, some model definitions have to
be altered to support tracing
* If the calibration dataset is text only, most LLMs and VLMs are
traceable without additional work. Multimodal calibration datasets are
more likely to require additional work to make the model traceable (see
the illustrative snippet after this list)
* For many VLMs (but not all), the vision tower is not traceable without
significant work. However, this only affects sequential error
propagation and (minimal?) increased memory usage, which leaves the door
open for future support for quantizing modules in the vision tower
* Add traceable model definitions for llava, mistral, mllama, and glm
* All copyright licenses allow for alteration and redistribution, the
line `# vllm-project: no copyright` was added in similar style to
[text_generation.py](https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/transformers/finetune/text_generation.py#L18)
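
As an illustration of why such alterations are sometimes needed (a generic workaround, not taken from the traceable definitions added here): data-dependent Python control flow or indexing breaks `torch.fx` symbolic tracing, and one common fix is to hide it behind `torch.fx.wrap` so it becomes a leaf call:

```python
import torch
import torch.fx


@torch.fx.wrap  # treated as a leaf call, so tracing does not descend into the body
def select_rows(hidden_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Data-dependent indexing like this cannot be traced inline with symbolic proxies
    return hidden_states[mask.bool()]


class Block(torch.nn.Module):
    def forward(self, hidden_states, mask):
        return select_rows(hidden_states, mask)


print(torch.fx.symbolic_trace(Block()).graph)
```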

## Future Work / Follow-ups ##
* #1027
* #1032
* #1039
* #1030
* Create better data collators capable of handling larger batch sizes in
order to support VLM fine tuning
* Better support prompt masking for multimodal processors in order to
support VLM fine tuning

## Winogrande Evaluations ##

Model | Dataset | Scheme | Runtime | Winogrande |
-- | -- | -- | -- | --
Llama-3-8B | ultrachat | W4A16 | 43m, 2xA4000 | 0.7545 
Llama-3-70B | ultrachat | W4A16 | 303m, 1xH100 | 0.8216 
Mixtral-8x7B | ultrachat | W4A16 | 317m, 1xA100 | 0.8200 
openbmb/MiniCPM3-4B | ultrachat | W4A16 | 63m, 1xA100 | 0.6701 
Qwen2-VL-2B-Instruct | ultrachat | W8A8 | 12m, 2xA4000 | 0.6188 
Qwen2-VL-2B-Instruct | flickr | W8A8 | 24m, 2xA4000 | 0.6093 
Llama-3.2-11B-Vision-Instruct | flickr | W8A8 | 75m, 1xA100 | 0.7837 
Pixtral-12B-2409 | flickr | W8A8 | 52m, 1xA100 | 0.7924 
llava-1.5-7b-hf | flickr | W8A8 | 15m, 1xH100 | 0.7214 
Phi-3-vision-128k-instruct | flickr | W4A16 | 51m, 1xA100 | 0.7151 

`lm_eval --model vllm --model_args
pretrained="path/to/model",dtype=auto,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True
--tasks winogrande --num_fewshot 5 --batch_size 32`
`lm_eval --model vllm --model_args
pretrained="path/to/model",dtype=bfloat16,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True,max_num_seqs=1
--tasks winogrande --num_fewshot 5 --batch_size 1`

## MMMU Evaluations ##
Credit to @shubhra 

Model | Dataset | Scheme | MMMU
-- | -- | -- | --
Llama-3.2-11B-Vision | N/A | Dense | 0.4144
Llama-3.2-11B-Vision | N/A | FP8-dynamic | 0.4300
Llama-3.2-11B-Vision | flickr | W4A16 | 0.4377
Llama-3.2-11B-Vision | flickr | W4A16-group | 0.4211

Model | Dataset | Scheme | MMMU
-- | -- | -- | --
Llama-3.2-90B-Vision | N/A | Dense | 0.5388
Llama-3.2-90B-Vision | N/A | FP8-dynamic | 0.5278
Llama-3.2-90B-Vision | flickr | W4A16 | 0.5111
Llama-3.2-90B-Vision | flickr | W4A16-group | 0.5477

Model | Dataset | Scheme | MMMU
-- | -- | -- | --
Pixtral-12B-2409 | N/A | Dense | 0.5022
Pixtral-12B-2409 | N/A | FP8-dynamic | 0.5322
Pixtral-12B-2409 | flickr | W4A16 | 0.4500
Pixtral-12B-2409 | flickr | W4A16-group | 0.4689

## Testing ##
* [Nightly](https://github.com/neuralmagic/llm-compressor-testing/actions/runs/12640439996)

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>