
Commit a5f4292

Basic hunyuan dit implementation. (comfyanonymous#4102)
* Let tokenizers return weights to be stored in the saved checkpoint.
* Basic hunyuan dit implementation.
* Fix some resolutions not working.
* Support hydit checkpoint save.
* Init with right dtype.
* Switch to optimized attention in pooler.
* Fix black images on hunyuan dit.
1 parent f87810c commit a5f4292

File tree

15 files changed: +48196 -1 lines changed


comfy/ldm/hydit/attn_layers.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention


def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
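
A minimal sketch of driving apply_rotary_emb with a hand-rolled (cos, sin) table (not part of the commit): the sizes, theta base, and variable names below are illustrative assumptions; HunYuan-DiT precomputes its 2D image RoPE tables elsewhere in the model code.

import torch

# Illustrative sizes (assumed): batch 1, sequence 16, 4 heads, head_dim 32.
B, S, H, D = 1, 16, 4, 32
pos = torch.arange(S, dtype=torch.float32)                               # token positions
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.repeat_interleave(torch.outer(pos, inv_freq), 2, dim=-1)  # [S, D]; adjacent pairs share a frequency
cos_sin = (angles.cos(), angles.sin())                                   # real-space (cos, sin) branch

xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, cos_sin, head_first=False)     # [B, S, H, D] in, same shape out
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
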
class CrossAttention(nn.Module):
    """
    Use QK Normalization.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s1, c = x.shape   # [b, s1, D]
        _, s2, c = y.shape   # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # [b, s, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        q = q.transpose(-2, -3).contiguous()  # q -> B, L1, H, C - B, H, L1, C
        k = k.transpose(-2, -3).contiguous()  # k -> B, L2, H, C - B, H, C, L2
        v = v.transpose(-2, -3).contiguous()

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

        out = self.out_proj(context)  # context.reshape - B, L1, -1
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
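
A minimal usage sketch for CrossAttention (not part of the commit), assuming it runs inside a ComfyUI checkout so the comfy import above resolves. In ComfyUI the operations argument is a comfy.ops class; plain torch.nn is used here only as a stand-in because it also exposes Linear and LayerNorm with matching constructor arguments, and the tensor sizes are illustrative.

import torch
import torch.nn as nn

# Image stream attends to text features; 1408 hidden / 16 heads gives head_dim 88 (divisible by 8, <= 128).
attn = CrossAttention(qdim=1408, kdim=1024, num_heads=16, qk_norm=True, operations=nn)

img_tokens = torch.randn(2, 4096, 1408)   # (batch, image sequence, hidden_dim)
txt_tokens = torch.randn(2, 77, 1024)     # (batch, text sequence, text hidden_dim)
out, = attn(img_tokens, txt_tokens)       # forward returns a 1-tuple
print(out.shape)                          # torch.Size([2, 4096, 1408])
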
class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """
    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # qkv --> Wqkv
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        B, N, C = x.shape
        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
        q, k, v = qkv.unbind(0)  # [b, h, s, d]
        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        x = self.out_proj(x)
        x = self.proj_drop(x)

        out_tuple = (x,)

        return out_tuple
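
A matching sketch for the self-attention Attention block (not part of the commit), again with torch.nn standing in for the ComfyUI operations object and an assumed hand-rolled RoPE table; note that here the (cos, sin) pair is consumed in the head-first [B, H, S, D] layout.

import torch
import torch.nn as nn

self_attn = Attention(dim=1408, num_heads=16, qk_norm=True, operations=nn)

S, D = 4096, 1408 // 16                   # sequence length and head_dim (illustrative)
pos = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
angles = torch.repeat_interleave(torch.outer(pos, inv_freq), 2, dim=-1)  # [S, D]
freqs_cis_img = (angles.cos(), angles.sin())

x = torch.randn(2, S, 1408)
out, = self_attn(x, freqs_cis_img=freqs_cis_img)   # 1-tuple; out.shape == (2, S, 1408)
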
