@@ -319,13 +319,15 @@ def forward(self,
         text_states_mask = text_embedding_mask.bool()               # 2,77
         text_states_t5_mask = text_embedding_mask_t5.bool()         # 2,256
         b_t5, l_t5, c_t5 = text_states_t5.shape
-        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
-        text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1)  # 2,205,1024
 
-        clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
 
-        clip_t5_mask = clip_t5_mask
-        text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
+        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
+
+        padding = self.text_embedding_padding.to(text_states)
+
+        text_states[:, -self.text_len:] = torch.where(text_states_mask[:, -self.text_len:].unsqueeze(2), text_states[:, -self.text_len:], padding[:self.text_len])
+        text_states_t5[:, -self.text_len_t5:] = torch.where(text_states_t5_mask[:, -self.text_len_t5:].unsqueeze(2), text_states_t5[:, -self.text_len_t5:], padding[self.text_len:])
+
+        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
+        # clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
 
         _, _, oh, ow = x.shape
         th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
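For reference, a minimal standalone sketch of what the new path computes: each text stream is padded against its own slice of the shared padding table before the two streams are concatenated, instead of masking the already-concatenated tensor. The concrete sizes below (text_len=77, text_len_t5=256, hidden=1024) and the zero-valued padding table are illustrative assumptions based on the shape comments in the diff; in the model the table is `self.text_embedding_padding`.

```python
import torch

batch, text_len, text_len_t5, hidden = 2, 77, 256, 1024  # assumed sizes

text_states = torch.randn(batch, text_len, hidden)        # CLIP stream
text_states_t5 = torch.randn(batch, text_len_t5, hidden)  # T5 stream after mlp_t5
text_states_mask = torch.rand(batch, text_len) > 0.5      # stand-in attention masks
text_states_t5_mask = torch.rand(batch, text_len_t5) > 0.5
padding = torch.zeros(text_len + text_len_t5, hidden)     # stand-in for text_embedding_padding

# Pad each stream against its own slice of the padding table, then concatenate.
text_states[:, -text_len:] = torch.where(
    text_states_mask[:, -text_len:].unsqueeze(2),
    text_states[:, -text_len:],
    padding[:text_len],
)
text_states_t5[:, -text_len_t5:] = torch.where(
    text_states_t5_mask[:, -text_len_t5:].unsqueeze(2),
    text_states_t5[:, -text_len_t5:],
    padding[text_len:],
)

text_states = torch.cat([text_states, text_states_t5], dim=1)
print(text_states.shape)  # torch.Size([2, 333, 1024])
```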