from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)
@@ -13,12 +14,17 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=pos.device)
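+    # torch.float64 is not supported by PyTorch's MPS backend, so compute on the CPU instead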
+    if comfy.model_management.is_device_mps(pos.device):
+        device = torch.device("cpu")
+    else:
+        device = pos.device
+
+    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta ** scale)
-    out = torch.einsum("...n,d->...nd", pos.float(), omega)
+    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
-    return out.float()
+    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
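
Context for the change above, as a standalone sketch (not part of the commit): PyTorch's MPS backend does not implement float64, so the float64 RoPE frequency math is routed to the CPU and the result is handed back to the caller's device in float32. The helper name and constants below are illustrative only.

import torch

def rope_compute_device(device: torch.device) -> torch.device:
    # Mirrors the branch added in the diff: MPS has no float64 kernels,
    # so high-precision work falls back to the CPU.
    return torch.device("cpu") if device.type == "mps" else device

pos = torch.arange(16)                                # stand-in positions tensor
dev = rope_compute_device(pos.device)
scale = torch.linspace(0, 0.875, steps=8, dtype=torch.float64, device=dev)
omega = 1.0 / (10000 ** scale)                        # per-dimension frequencies, float64
out = torch.einsum("n,d->nd", pos.to(dtype=torch.float64, device=dev), omega)
out = out.to(dtype=torch.float32, device=pos.device)  # return on the original device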