Adding Q-LoRA support for Florence-2 #211

Open · wants to merge 3 commits into base: develop

4 changes: 4 additions & 0 deletions .github/workflows/welcome.yml
@@ -6,6 +6,10 @@ on:
   pull_request_target:
     types: [opened]

+permissions:
+  pull-requests: write
+  issues: write
+
 jobs:
   build:
     name: 👋 Welcome
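Presumably the new `permissions` block is what lets the welcome job act on incoming activity: with a default read-only `GITHUB_TOKEN`, a workflow that comments on newly opened pull requests or issues needs explicit `pull-requests: write` and `issues: write` scopes.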
22 changes: 20 additions & 2 deletions maestro/trainer/models/florence_2/checkpoints.py
@@ -4,7 +4,7 @@

 import torch
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

 from maestro.trainer.common.utils.device import parse_device_spec
 from maestro.trainer.logger import get_maestro_logger
@@ -26,6 +26,7 @@ class OptimizationStrategy(Enum):
"""Enumeration for optimization strategies."""

LORA = "lora"
QLORA = "qlora"
FREEZE = "freeze"
NONE = "none"

@@ -58,7 +59,7 @@ def load_model(
     device = parse_device_spec(device)
     processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision)

-    if optimization_strategy == OptimizationStrategy.LORA:
+    if optimization_strategy in (OptimizationStrategy.LORA, OptimizationStrategy.QLORA):
         default_params = DEFAULT_FLORENCE2_PEFT_PARAMS
         if peft_advanced_params is not None:
             default_params.update(peft_advanced_params)
@@ -71,13 +72,30 @@
         else:
             logger.info("No LoRA parameters provided. Using default configuration.")
         config = LoraConfig(**default_params)
+
+        bnb_config = None
+        if optimization_strategy == OptimizationStrategy.QLORA:
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+            logger.info("Using 4-bit quantization")
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id_or_path,
             revision=revision,
             trust_remote_code=True,
             cache_dir=cache_dir,
+            quantization_config=bnb_config,
+            device_map="auto" if optimization_strategy == OptimizationStrategy.QLORA else None,
         )
         model = get_peft_model(model, config).to(device)
+
+        if optimization_strategy == OptimizationStrategy.QLORA:
+            model = model.to(device)
+
         model.print_trainable_parameters()
     else:
         model = AutoModelForCausalLM.from_pretrained(
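For context, here is a minimal usage sketch of the QLoRA path added above. Treat it as hedged: only `OptimizationStrategy.QLORA`, the `model_id_or_path` parameter, and the `BitsAndBytesConfig` settings (4-bit NF4 weights, double quantization, bfloat16 compute) come from this diff; the `(processor, model)` return shape and the model ID are assumptions.

```python
# Hypothetical sketch of loading Florence-2 with the new QLORA strategy.
# Assumptions: load_model returns (processor, model); bitsandbytes and a
# CUDA device are available; the model ID below is illustrative only.
from maestro.trainer.models.florence_2.checkpoints import (
    OptimizationStrategy,
    load_model,
)

processor, model = load_model(
    model_id_or_path="microsoft/Florence-2-base-ft",
    optimization_strategy=OptimizationStrategy.QLORA,
)

# The 4-bit quantized base stays frozen; only LoRA adapters are trainable.
model.print_trainable_parameters()
```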
2 changes: 1 addition & 1 deletion maestro/trainer/models/florence_2/core.py
@@ -92,7 +92,7 @@ class Florence2Configuration:
     model_id: str = DEFAULT_FLORENCE2_MODEL_ID
     revision: str = DEFAULT_FLORENCE2_MODEL_REVISION
     device: str | torch.device = "auto"
-    optimization_strategy: Literal["lora", "freeze", "none"] = "lora"
+    optimization_strategy: Literal["lora", "qlora", "freeze", "none"] = "lora"
     cache_dir: Optional[str] = None
     epochs: int = 10
     lr: float = 1e-5
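The configuration side is a one-word change: `"qlora"` becomes a valid literal. A minimal sketch, assuming the remaining fields of the dataclass keep the defaults shown above (any required fields not visible in this hunk would still need to be supplied):

```python
# Hypothetical sketch: selecting QLoRA through the configuration dataclass.
from maestro.trainer.models.florence_2.core import Florence2Configuration

config = Florence2Configuration(
    optimization_strategy="qlora",  # newly accepted by the Literal above
    epochs=10,  # defaults shown in this hunk, repeated for clarity
    lr=1e-5,
)
```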
2 changes: 1 addition & 1 deletion maestro/trainer/models/florence_2/entrypoint.py
@@ -35,7 +35,7 @@ def train(
     ] = DEFAULT_FLORENCE2_MODEL_REVISION,
     device: Annotated[str, typer.Option("--device", help="Device to use for training")] = "auto",
     optimization_strategy: Annotated[
-        str, typer.Option("--optimization_strategy", help="Optimization strategy: lora, freeze, or none")
+        str, typer.Option("--optimization_strategy", help="Optimization strategy: lora, qlora, freeze, or none")
     ] = "lora",
     cache_dir: Annotated[
         Optional[str], typer.Option("--cache_dir", help="Directory to cache the model weights locally")
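With the help text updated, the same choice is exposed on the command line. The invocation below is an assumption (the `maestro florence_2 train` command form and any other required flags are not shown in this diff); only `--optimization_strategy qlora` comes from the change above:

```bash
maestro florence_2 train --optimization_strategy qlora --device auto
```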