
Commit 8115d8c

Add Flux fp16 support hack.

Parent: 6969fc9

2 files changed: +9, -2 lines

comfy/ldm/flux/layers.py (8 additions, 1 deletion)

```diff
@@ -188,6 +188,10 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)
+
         return img, txt


@@ -239,7 +243,10 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x = x + mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x


 class LastLayer(nn.Module):
```
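Both hunks apply the same pattern: float16's largest finite value is 65504 (torch.finfo(torch.float16).max), and the residual additions in Flux's double- and single-stream blocks can exceed it, overflowing to inf and then poisoning everything downstream with NaN. Clamping right after each residual add saturates instead of overflowing. A minimal standalone sketch of the same trick outside ComfyUI (the saturate_fp16 helper name is mine, not from this commit):

```python
import torch

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0

def saturate_fp16(t: torch.Tensor) -> torch.Tensor:
    # Clamp fp16 tensors into the finite range so an overflowed
    # residual add saturates at +/-65504 instead of staying inf.
    if t.dtype == torch.float16:
        t = t.clip(-FP16_MAX, FP16_MAX)
    return t

x = torch.full((2,), 60000.0, dtype=torch.float16)
print(x + x)                 # tensor([inf, inf], dtype=torch.float16)
print(saturate_fp16(x + x))  # tensor([65504., 65504.], dtype=torch.float16)
```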

comfy/supported_models.py (1 addition, 1 deletion)

```diff
@@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):
 
     memory_usage_factor = 2.8
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
```
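This one-line change merely advertises torch.float16 as a dtype Flux can now run in; the actual dtype choice happens elsewhere in ComfyUI's model-management code. As a rough, hypothetical illustration of how such a preference list might be consumed (this is not ComfyUI's real selection logic):

```python
import torch

def pick_inference_dtype(supported: list[torch.dtype],
                         device: torch.device) -> torch.dtype:
    # Hypothetical helper: walk the model's preference list and return
    # the first dtype the device can actually execute.
    for dt in supported:
        if dt == torch.bfloat16 and device.type == "cuda" and torch.cuda.is_bf16_supported():
            return dt
        if dt == torch.float16 and device.type == "cuda":
            return dt
        if dt == torch.float32:
            return dt
    return torch.float32

# With fp16 now in the list, cards without bf16 support can fall
# through to float16 instead of float32:
print(pick_inference_dtype(
    [torch.bfloat16, torch.float16, torch.float32], torch.device("cuda")))
```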
