@@ -1,6 +1,7 @@
 import torch
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops

 class T5LayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -11,7 +12,7 @@ def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=No
     def forward(self, x):
         variance = x.pow(2).mean(-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight.to(device=x.device, dtype=x.dtype) * x
+        return comfy.ops.cast_to_input(self.weight, x) * x

 activations = {
     "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -82,7 +83,7 @@ def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dty
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -132,7 +133,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, device):
+    def compute_bias(self, query_length, key_length, device, dtype):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -143,7 +144,7 @@ def compute_bias(self, query_length, key_length, device):
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values

@@ -152,7 +153,7 @@ def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         k = self.k(x)
         v = self.v(x)
         if self.relative_attention_bias is not None:
-            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

         if past_bias is not None:
             if mask is not None:
@@ -225,7 +226,7 @@ def __init__(self, config_dict, dtype, device, operations):

         self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

     def get_input_embeddings(self):
         return self.shared
@@ -234,5 +235,5 @@ def set_input_embeddings(self, embeddings):
         self.shared = embeddings

     def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids)
+        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         return self.encoder(x, *args, **kwargs)
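
For context, a minimal sketch of the cast-at-use pattern this commit relies on, assuming a cast_to_input helper and an Embedding variant that accepts an out_dtype argument (hypothetical stand-ins for illustration, not the actual comfy.ops implementations):

# Minimal sketch, not the actual comfy.ops code: parameters stay in their
# storage dtype and are cast only when they meet an activation tensor.
import torch

def cast_to_input(weight, input):
    # Hypothetical stand-in for comfy.ops.cast_to_input: match the parameter
    # to the device and dtype of the tensor it will be multiplied with.
    return weight.to(device=input.device, dtype=input.dtype)

class CastEmbedding(torch.nn.Embedding):
    # Hypothetical stand-in for operations.Embedding: the lookup table keeps
    # its storage dtype; out_dtype controls the dtype of the returned tensor.
    def forward(self, input, out_dtype=None):
        out = torch.nn.functional.embedding(input, self.weight)
        return out if out_dtype is None else out.to(out_dtype)

# Usage: fp16 weight storage, fp32 activations.
emb = CastEmbedding(32, 8, dtype=torch.float16)
tokens = torch.tensor([[1, 2, 3]])
x = emb(tokens, out_dtype=torch.float32)
scale = torch.nn.Parameter(torch.ones(8, dtype=torch.float16))
print((cast_to_input(scale, x) * x).dtype)  # torch.float32

This mirrors what the diff above does for T5: the LayerNorm scale and the embedding tables can be stored in a lower-precision dtype, while the values actually fed into the rest of the network follow the activation or requested dtype.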