[LoRA] feat: support loading loras into 4bit quantized Flux models. #10578
Diff for `_maybe_expand_transformer_param_shape_or_error_`:

```diff
@@ -1982,9 +1982,19 @@ def _maybe_expand_transformer_param_shape_or_error_(
             out_features = state_dict[lora_B_weight_name].shape[0]

             # This means there's no need for an expansion in the params, so we simply skip.
-            if tuple(module_weight.shape) == (out_features, in_features):
+            module_weight_shape = module_weight.shape
+            expansion_shape = (out_features, in_features)
+            quantization_config = getattr(transformer, "quantization_config", None)
+            if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                if quantization_config.load_in_4bit:
+                    expansion_shape = torch.Size(expansion_shape).numel()
+                    expansion_shape = ((expansion_shape + 1) // 2, 1)
+
+            if tuple(module_weight_shape) == expansion_shape:
                 continue

             # TODO (sayakpaul): We still need to consider if the module we're expanding is
             # quantized and handle it accordingly if that is the case.
             module_out_features, module_in_features = module_weight.shape
             debug_message = ""
             if in_features > module_in_features:
```

Review thread on `expansion_shape = ((expansion_shape + 1) // 2, 1)`:

- Only 4bit bnb models flatten.
- How about adding a comment along the lines of: "Handle 4bit bnb weights, which are flattened and compress 2 params into 1"? I'm not quite sure why we need `(shape + 1) // 2`; maybe this could be added to the comment too.
- Yeah, this comes from …
- This is for rounding, i.e. if …
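For context on the `(expansion_shape + 1) // 2` rounding questioned above: bitsandbytes stores a 4-bit quantized weight as a flattened single-column tensor in which each byte packs two 4-bit values, so a logical weight with `numel` elements is stored with shape `((numel + 1) // 2, 1)`; the `+ 1` rounds up when `numel` is odd. A minimal sketch of that shape arithmetic (no bitsandbytes import needed; the helper name is made up for illustration):

```python
import torch


def packed_4bit_shape(logical_shape):
    """Storage shape of a bitsandbytes 4-bit weight: two 4-bit values per byte,
    flattened to a single column, rounded up so an odd element count still fits."""
    numel = torch.Size(logical_shape).numel()
    return ((numel + 1) // 2, 1)


print(packed_4bit_shape((3072, 64)))  # (98304, 1): 196608 values fit in 98304 bytes
print(packed_4bit_shape((5, 3)))      # (8, 1): 15 values need 8 bytes, hence the rounding
```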
Diff for `_maybe_expand_lora_state_dict`:

```diff
@@ -2080,13 +2090,22 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
             base_weight_param = transformer_state_dict[base_param_name]
             lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

-            if base_weight_param.shape[1] > lora_A_param.shape[1]:
             # TODO (sayakpaul): Handle the cases when we actually need to expand.
+            base_out_feature_shape = base_weight_param.shape[1]
+            lora_A_out_feature_shape = lora_A_param.shape[1]
+            quantization_config = getattr(transformer, "quantization_config", None)
+            if quantization_config and quantization_config.quant_method == "bitsandbytes":
+                if quantization_config.load_in_4bit:
+                    lora_A_out_feature_shape = lora_A_param.shape.numel()
+                    lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1]
+
+            if base_out_feature_shape > lora_A_out_feature_shape:
                 shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                 expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                 expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                 lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                 expanded_module_names.add(k)
-            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+            elif base_out_feature_shape < lora_A_out_feature_shape:
                 raise NotImplementedError(
                     f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                 )
```
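The unchanged expansion branch above zero-pads the LoRA A matrix along its input dimension so it lines up with the wider base layer. A standalone sketch of just that padding step, using made-up dimensions:

```python
import torch

# Hypothetical sizes: a rank-16 LoRA trained against 3072 input features,
# loaded into a layer that expects 3136 input features.
rank, lora_in_features, base_in_features = 16, 3072, 3136

lora_A = torch.randn(rank, lora_in_features)

# Zero-pad the extra columns; the padded region contributes nothing to the LoRA update.
expanded = torch.zeros(rank, base_in_features)
expanded[:, :lora_in_features].copy_(lora_A)

assert expanded.shape == (rank, base_in_features)
```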
Review comment: Would it make sense to have a utility function to get the shape, to avoid code duplication?

Reply: Yes, this needs to happen. As I mentioned, this PR is very much a PoC to gather feedback, and I will refine it. But I wanted to first explore whether this is a good way to approach the problem.
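A rough sketch of the kind of shared shape helper suggested above, so both call sites compute the expected storage shape in one place. The function name and exact signature are hypothetical, not something from this PR or the diffusers API:

```python
import torch


def _expected_storage_shape(logical_shape, quantization_config=None):
    """Hypothetical helper: map a layer's logical (out_features, in_features) shape
    to the shape actually stored on the module, accounting for bitsandbytes 4-bit
    packing (two 4-bit values per byte, flattened to a single column)."""
    if (
        quantization_config is not None
        and getattr(quantization_config, "quant_method", None) == "bitsandbytes"
        and getattr(quantization_config, "load_in_4bit", False)
    ):
        numel = torch.Size(logical_shape).numel()
        return torch.Size(((numel + 1) // 2, 1))
    return torch.Size(logical_shape)
```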