@@ -124,15 +124,21 @@ def forward_orig(
124124
125125 def forward (self , x , timestep , context , y , guidance , ** kwargs ):
126126 bs , c , h , w = x .shape
127- img = rearrange (x , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = 2 , pw = 2 )
127+ patch_size = 2
128+ pad_h = (patch_size - h % 2 ) % patch_size
129+ pad_w = (patch_size - w % 2 ) % patch_size
128130
129- h_len = (h // 2 )
130- w_len = (w // 2 )
131+ x = torch .nn .functional .pad (x , (0 , pad_w , 0 , pad_h ), mode = 'circular' )
132+
133+ img = rearrange (x , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = patch_size , pw = patch_size )
134+
135+ h_len = ((h + (patch_size // 2 )) // patch_size )
136+ w_len = ((w + (patch_size // 2 )) // patch_size )
131137 img_ids = torch .zeros ((h_len , w_len , 3 ), device = x .device , dtype = x .dtype )
132138 img_ids [..., 1 ] = img_ids [..., 1 ] + torch .linspace (0 , h_len - 1 , steps = h_len , device = x .device , dtype = x .dtype )[:, None ]
133139 img_ids [..., 2 ] = img_ids [..., 2 ] + torch .linspace (0 , w_len - 1 , steps = w_len , device = x .device , dtype = x .dtype )[None , :]
134140 img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
135141
136142 txt_ids = torch .zeros ((bs , context .shape [1 ], 3 ), device = x .device , dtype = x .dtype )
137143 out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance )
138- return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h_len , w = w_len , ph = 2 , pw = 2 )
144+ return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h_len , w = w_len , ph = 2 , pw = 2 )[:,:,: h ,: w ]
0 commit comments