 from .math import attention, rope
 import comfy.ops
 
-
 class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+    def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
         self.dim = dim
         self.theta = theta
@@ -79,7 +78,7 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
         self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
         self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
 
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)
         k = self.key_norm(k)
         return q.to(v), k.to(v)
@@ -118,7 +117,7 @@ def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=N
         self.multiplier = 6 if double else 3
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
 
-    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+    def forward(self, vec: Tensor) -> tuple:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
 
         return (
@@ -156,7 +155,7 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
@@ -203,7 +202,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
+        qk_scale: float = None,
         dtype=None,
         device=None,
         operations=None
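The annotations removed above depend on newer Python syntax: subscripting built-in containers (list[int], tuple[Tensor, Tensor]) requires Python 3.9+, and the "X | None" union form requires Python 3.10+. If compatibility with older interpreters is the motivation (an assumption, the commit does not say), the same type information could instead be kept via the typing module. A minimal sketch with hypothetical function names, not part of the commit:

# Sketch only: preserves the dropped type information on Python 3.8+ by using
# typing-module generics and Optional instead of built-in generics and "X | None".
# Function names below are hypothetical illustrations, not the project's API.
from typing import List, Optional, Tuple

from torch import Tensor


def qknorm_forward(q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
    # Tuple[Tensor, Tensor] expresses the same return type as the removed
    # "-> tuple[Tensor, Tensor]" annotation, but works on Python 3.8.
    return q.to(v), k.to(v)


def build_config(dim: int, axes_dim: List[int],
                 qk_scale: Optional[float] = None) -> None:
    # List[int] replaces list[int]; Optional[float] replaces "float | None".
    pass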