import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention


def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor (or (cos, sin) tuple) to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): Whether the head dimension comes before the sequence dimension,
            i.e. x is laid out as [B, H, S, D] rather than [B, S, H, D].

    Returns:
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Reshaped frequency tensor(s).

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)
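
# Worked shape example for reshape_for_broadcast (illustrative only; the
# concrete sizes below are assumptions, not fixed by this module):
#   x (head_first=False): [B, S, H, D] = [2, 64, 8, 32]
#   freqs_cis tuple:      cos, sin each of shape [S, D] = [64, 32]
#   -> the returned views have shape [1, 64, 1, 32], which broadcast against x.
#   With head_first=True the layout is [B, H, S, D] and the views become
#   [1, 1, 64, 32] instead.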


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
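
# Note on rotate_half (a clarifying sketch, not part of the upstream code):
# the last dimension is treated as interleaved (real, imag) pairs, and each
# pair is rotated by 90 degrees, e.g. for one token/head
#   (x0, x1, x2, x3)  ->  (-x1, x0, -x3, x2)
# so that x * cos + rotate_half(x) * sin realizes the usual RoPE rotation
# when cos/sin are repeated per pair along the last dimension.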


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (Optional[torch.Tensor]): Key tensor to apply rotary embeddings, or None to skip the key. [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor of complex exponentials,
            or a (cos, sin) tuple in real space.
        head_first (bool): Whether the head dimension comes before the sequence dimension (i.e. [B, H, S, D]).

    Returns:
        Tuple[torch.Tensor, Optional[torch.Tensor]]: Query and key tensors with rotary embeddings applied;
            the key output is None when 'xk' is None.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
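
# Usage sketch for apply_rotary_emb (illustrative; the shapes and the way the
# (cos, sin) tables are built below are assumptions, not defined here):
#
#     q = torch.randn(2, 64, 8, 32)                         # [B, S, H, D]
#     k = torch.randn(2, 64, 8, 32)
#     angles = torch.rand(64, 16)                           # [S, D//2], hypothetical angles
#     cos = torch.cos(angles).repeat_interleave(2, dim=-1)  # [S, D], per-pair
#     sin = torch.sin(angles).repeat_interleave(2, dim=-1)
#     q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
#     assert q_rot.shape == q.shape and k_rot.shape == k.shape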


class CrossAttention(nn.Module):
    """
    Use QK Normalization.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s1, c = x.shape  # [b, s1, D]
        _, s2, c = y.shape  # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # [b, s2, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        q = q.transpose(-2, -3).contiguous()  # q: [b, s1, h, d] -> [b, h, s1, d]
        k = k.transpose(-2, -3).contiguous()  # k: [b, s2, h, d] -> [b, h, s2, d]
        v = v.transpose(-2, -3).contiguous()  # v: [b, s2, h, d] -> [b, h, s2, d]

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)  # [b, s1, h * d]

        out = self.out_proj(context)  # [b, s1, qdim]
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
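
# Usage sketch for CrossAttention (illustrative only; the dimensions and the
# choice of `operations=nn` are assumptions made here for clarity, since
# ComfyUI normally passes its own operations namespace with the same
# Linear/LayerNorm interface):
#
#     attn = CrossAttention(qdim=256, kdim=1024, num_heads=8, qk_norm=True, operations=nn)
#     x = torch.randn(2, 16, 256)    # image tokens  [b, s1, qdim]
#     y = torch.randn(2, 77, 1024)   # text tokens   [b, s2, kdim]
#     out, = attn(x, y)              # out: [2, 16, 256]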


class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention.
    """
    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # qkv --> Wqkv
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        B, N, C = x.shape
        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
        q, k, v = qkv.unbind(0)  # [b, h, s, d]
        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        x = self.out_proj(x)
        x = self.proj_drop(x)

        out_tuple = (x,)

        return out_tuple
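

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the upstream module): the sizes
    # below and the use of plain `torch.nn` as the `operations` namespace are
    # assumptions for illustration; ComfyUI passes its own operations object
    # with a compatible Linear/LayerNorm interface.
    attn = Attention(dim=256, num_heads=8, qk_norm=True, operations=nn)
    x = torch.randn(2, 16, 256)                  # [batch, tokens, dim]
    (out,) = attn(x)                             # self-attention, no RoPE
    print("Attention output:", out.shape)        # torch.Size([2, 16, 256])

    cross = CrossAttention(qdim=256, kdim=1024, num_heads=8, operations=nn)
    y = torch.randn(2, 77, 1024)                 # [batch, text tokens, kdim]
    (cout,) = cross(x, y)                        # image tokens attend to text
    print("CrossAttention output:", cout.shape)  # torch.Size([2, 16, 256])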