Skip to content

[LoRA] feat: support loading loras into 4bit quantized Flux models. #10578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,9 +1982,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
out_features = state_dict[lora_B_weight_name].shape[0]

# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight.shape) == (out_features, in_features):
module_weight_shape = module_weight.shape
expansion_shape = (out_features, in_features)
quantization_config = getattr(transformer, "quantization_config", None)
if quantization_config and quantization_config.quant_method == "bitsandbytes":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to have a utility function to get the shape to avoid code duplication?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this needs to happen. As I mentioned this PR is very much a PoC to gather feedback and I will refine it. But I wanted to first explore if this a good way to approach the problem.

if quantization_config.load_in_4bit:
expansion_shape = torch.Size(expansion_shape).numel()
expansion_shape = ((expansion_shape + 1) // 2, 1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only 4bit bnb models flatten.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a comment along the lines of: "Handle 4bit bnb weights, which are flattened and compress 2 params into 1".

I'm not quite sure why we need (shape+1) // 2, maybe this could be added to the comment too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this comes from bitsandbytes. Cc: @matthewdouglas

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for rounding, i.e. if expansion_shape is odd it will have an additional 8bit tensor with just one value packed into it.


if tuple(module_weight_shape) == expansion_shape:
continue

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
Expand Down Expand Up @@ -2080,13 +2090,22 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

if base_weight_param.shape[1] > lora_A_param.shape[1]:
# TODO (sayakpaul): Handle the cases when we actually need to expand.
base_out_feature_shape = base_weight_param.shape[1]
lora_A_out_feature_shape = lora_A_param.shape[1]
quantization_config = getattr(transformer, "quantization_config", None)
if quantization_config and quantization_config.quant_method == "bitsandbytes":
if quantization_config.load_in_4bit:
lora_A_out_feature_shape = lora_A_param.shape.numel()
lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]

if base_out_feature_shape > lora_A_out_feature_shape:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
elif lora_A_out_feature_shape < lora_A_out_feature_shape:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
Expand Down
24 changes: 24 additions & 0 deletions tests/quantization/bnb/test_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import is_accelerate_version, logging
Expand All @@ -32,6 +33,7 @@
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
require_peft_version_greater,
require_torch,
require_torch_gpu,
require_transformers_version_greater,
Expand Down Expand Up @@ -568,6 +570,28 @@ def test_quality(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)

@require_peft_version_greater("0.14.0")
def test_lora_loading_works(self):
self.pipeline_4bit.load_lora_weights(
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)

output = self.pipeline_4bit(
prompt=self.prompt,
height=256,
width=256,
max_sequence_length=64,
output_type="np",
num_inference_steps=8,
generator=torch.Generator().manual_seed(42),
).images
out_slice = output[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])

max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
self.assertTrue(max_diff < 1e-3)


@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
Expand Down
Loading