
Commit 8d34211

Fix old python versions no longer working.
1 parent 1589b58 commit 8d34211

File tree

3 files changed: +8 −9 lines changed


comfy/ldm/flux/layers.py

Lines changed: 5 additions & 6 deletions
@@ -8,9 +8,8 @@
 from .math import attention, rope
 import comfy.ops

-
 class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+    def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
         self.dim = dim
         self.theta = theta
@@ -79,7 +78,7 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
         self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
         self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)
         k = self.key_norm(k)
         return q.to(v), k.to(v)
@@ -118,7 +117,7 @@ def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
         self.multiplier = 6 if double else 3
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+    def forward(self, vec: Tensor) -> tuple:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

         return (
@@ -156,7 +155,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

@@ -203,7 +202,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
+        qk_scale: float = None,
         dtype=None,
         device=None,
         operations=None
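
For context (not part of the diff itself): built-in generics such as list[int] and tuple[Tensor, Tensor] are only subscriptable at runtime on Python 3.9+, and the X | None union syntax needs Python 3.10+, so these annotations raise TypeError while the module is imported on older interpreters. The commit simply drops the subscripts and unions. A minimal sketch of the typing-module spellings that would keep the same detail while still importing on older 3.x releases; the function names below are made up for illustration and are not code from the repository:

from typing import List, Optional, Tuple

from torch import Tensor  # assumes torch is installed in the environment


def example_embed(dim: int, theta: int, axes_dim: List[int]) -> Tuple[Tensor, ...]:
    # List[int] / Tuple[...] are the pre-3.9 spellings of list[int] / tuple[...]
    ...


def example_block(qk_scale: Optional[float] = None) -> None:
    # Optional[float] is the pre-3.10 spelling of float | None
    ...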

comfy/ldm/flux/math.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     return out.float()


-def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
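
An alternative that keeps a richer return hint without breaking older interpreters would be PEP 563 postponed evaluation (from __future__ import annotations, available since Python 3.7), which stores annotations as strings so the 3.9+/3.10+ syntax is never executed at import time. This is a hedged sketch of that alternative, not what the commit does; it only breaks if something later calls typing.get_type_hints() on an old interpreter:

from __future__ import annotations

from torch import Tensor  # assumed available


def returns_pair(xq: Tensor, xk: Tensor) -> tuple[Tensor, Tensor]:
    # Under PEP 563 the annotation above is stored as the string
    # "tuple[Tensor, Tensor]" and never evaluated, so this module
    # still imports cleanly on Python 3.7/3.8.
    return xq, xk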

comfy/ldm/flux/model.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ class FluxParams:
     num_heads: int
     depth: int
     depth_single_blocks: int
-    axes_dim: list[int]
+    axes_dim: list
     theta: int
     qkv_bias: bool
     guidance_embed: bool
@@ -92,7 +92,7 @@ def forward_orig(
         txt_ids: Tensor,
         timesteps: Tensor,
         y: Tensor,
-        guidance: Tensor | None = None,
+        guidance: Tensor = None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
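
As a quick sanity check of the failure mode this commit works around (a standalone snippet, not repository code): on a pre-3.9 interpreter, evaluating a built-in generic like list[int] raises TypeError, and a Tensor | None union would likewise fail before 3.10, which is why importing these modules broke on older Python versions.

import sys

# Demonstrates why annotations like list[int] broke module import on old Pythons.
if sys.version_info < (3, 9):
    try:
        list[int]  # PEP 585 built-in generics need Python 3.9+
    except TypeError as exc:
        print("list[int] is not subscriptable here:", exc)
else:
    print("PEP 585 generics available:", list[int])

if sys.version_info < (3, 10):
    print("X | Y union annotations (PEP 604) would also fail before 3.10")
else:
    print("PEP 604 unions available:", int | None)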
