@@ -38,7 +38,9 @@ def encode_token_weights(self, token_weight_pairs):
         if has_weights or sections == 0:
             to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
 
-        out, pooled = self.encode(to_encode)
+        o = self.encode(to_encode)
+        out, pooled = o[:2]
+
         if pooled is not None:
             first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
@@ -57,8 +59,11 @@ def encode_token_weights(self, token_weight_pairs):
             output.append(z)
 
         if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+        else:
+            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+        return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True):  # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -96,6 +101,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
+        self.return_attention_masks = return_attention_masks
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -169,7 +175,7 @@ def forward(self, tokens):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked:
+        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -200,6 +206,9 @@ def forward(self, tokens):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        if self.return_attention_masks:
+            return z, pooled_output, attention_mask
+
         return z, pooled_output
 
     def encode(self, tokens):
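
The tuple-extension logic added to `encode_token_weights` can be illustrated outside ComfyUI with plain tensors; the following is a minimal, self-contained sketch (not part of the commit, shapes and values are illustrative) of how extra outputs from `encode()` beyond `(out, pooled)` end up appended to the returned tuple:

```python
import torch

# Sketch of the pattern: any extra values returned after (out, pooled) --
# here an attention mask -- are sliced to the number of prompt sections,
# flattened, and appended to the return tuple. Shapes are illustrative.
sections, max_length, emb_dim = 2, 77, 768
o = (
    torch.randn(sections + 1, max_length, emb_dim),          # out: hidden states (incl. empty-prompt entry)
    torch.randn(sections + 1, emb_dim),                      # pooled
    torch.ones(sections + 1, max_length, dtype=torch.long),  # attention_mask returned by forward()
)

r = (o[0][-1:], o[1][0:1])  # stand-in for the (cond, first_pooled) pair
r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0), o[2:]))
print(r[2].shape)  # torch.Size([1, 154]) == (1, sections * max_length)
```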