
Commit 8bfb1dd

Fix

1 parent 9e81057

2 files changed: +4 −4 lines


ppdiffusers/ppdiffusers/models/unet_2d_blocks.py
1 addition, 1 deletion
@@ -720,7 +720,6 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                hidden_states = recompute(create_custom_forward(resnet), hidden_states, temb)
                 hidden_states = recompute(
                     create_custom_forward(attn, return_dict=False),
                     hidden_states,
@@ -729,6 +728,7 @@ def custom_forward(*inputs):
                     attention_mask,
                     encoder_attention_mask,
                 ) # [0]
+                hidden_states = recompute(create_custom_forward(resnet), hidden_states, temb, lora_scale)
             else:
                 hidden_states = attn(
                     hidden_states,
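
For context, here is a minimal sketch of the `create_custom_forward` / `recompute` pattern this hunk reorders, assuming Paddle's `paddle.distributed.fleet.utils.recompute`. The two `Linear` layers are hypothetical stand-ins for the block's real `attn` and `resnet` submodules, which in the actual code also receive `temb` and `lora_scale` as extra positional arguments:

```python
import paddle
from paddle.distributed.fleet.utils import recompute


def create_custom_forward(module, return_dict=None):
    # Close over keyword-only options so recompute can invoke the module
    # with positional tensors alone, as in the diff above.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


# Hypothetical stand-ins for the mid-block's attn/resnet submodules.
attn = paddle.nn.Linear(8, 8)
resnet = paddle.nn.Linear(8, 8)

hidden_states = paddle.randn([2, 8])
hidden_states.stop_gradient = False

# The fixed ordering from the commit: attention first, then the resnet.
# recompute re-runs each wrapped forward during backward instead of
# storing its activations, trading compute for memory.
hidden_states = recompute(create_custom_forward(attn), hidden_states)
hidden_states = recompute(create_custom_forward(resnet), hidden_states)
hidden_states.sum().backward()
```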

ppdiffusers/tests/models/test_models_unet_2d_condition.py
3 additions, 3 deletions
@@ -184,12 +184,12 @@ def test_gradient_checkpointing(self):
         model_2.clear_gradients()
         loss_2 = (out_2 - labels).mean()
         loss_2.backward()
-        # UNetMidBlock2DCrossAttn create_custom_forward increases the difference.
-        self.assertTrue((loss - loss_2).abs() < 1e-03)
+        # UNetMidBlock2DCrossAttn create_custom_forward associates the difference.
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
         named_params = dict(model.named_parameters())
         named_params_2 = dict(model_2.named_parameters())
         for name, param in named_params.items():
-            self.assertTrue(paddle_all_close(param.grad, named_params_2[name].grad, atol=1e-03))
+            self.assertTrue(paddle_all_close(param.grad, named_params_2[name].grad, atol=5e-5))
 
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
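
The tightened tolerances gate a check along these lines: after running the same inputs through a plain model and a gradient-checkpointed copy, the losses must agree within 1e-5 and every parameter gradient within 5e-5. A hedged sketch of that comparison, using `paddle.allclose` as a stand-in for the suite's `paddle_all_close` helper and a hypothetical `grads_match` name:

```python
import paddle


def grads_match(model, model_2, atol=5e-5):
    # Compare gradients parameter-by-parameter between the plain model and
    # its gradient-checkpointed copy, mirroring the loop in the test above.
    params = dict(model.named_parameters())
    params_2 = dict(model_2.named_parameters())
    return all(
        paddle.allclose(p.grad, params_2[name].grad, atol=atol).item()
        for name, p in params.items()
    )
```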
