@@ -126,11 +126,13 @@ def forward(self, x, timestep, context, y, guidance, **kwargs):
126126 bs , c , h , w = x .shape
127127 img = rearrange (x , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = 2 , pw = 2 )
128128
129- img_ids = torch .zeros ((h // 2 , w // 2 , 3 ), device = x .device , dtype = x .dtype )
130- img_ids [..., 1 ] = img_ids [..., 1 ] + torch .arange (h // 2 , device = x .device , dtype = x .dtype )[:, None ]
131- img_ids [..., 2 ] = img_ids [..., 2 ] + torch .arange (w // 2 , device = x .device , dtype = x .dtype )[None , :]
129+ h_len = (h // 2 )
130+ w_len = (w // 2 )
131+ img_ids = torch .zeros ((h_len , w_len , 3 ), device = x .device , dtype = x .dtype )
132+ img_ids [..., 1 ] = img_ids [..., 1 ] + torch .linspace (0 , h_len - 1 , steps = h_len , device = x .device , dtype = x .dtype )[:, None ]
133+ img_ids [..., 2 ] = img_ids [..., 2 ] + torch .linspace (0 , w_len - 1 , steps = w_len , device = x .device , dtype = x .dtype )[None , :]
132134 img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
133135
134136 txt_ids = torch .zeros ((bs , context .shape [1 ], 3 ), device = x .device , dtype = x .dtype )
135137 out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance )
136- return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h // 2 , w = w // 2 , ph = 2 , pw = 2 )
138+ return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h_len , w = w_len , ph = 2 , pw = 2 )
0 commit comments