
Commit a27125d

MikeTkachuk authored and sayakpaul committed
Enabling gradient checkpointing in eval() mode (#9878)
* refactored
1 parent 55ec25c commit a27125d

34 files changed: +84 −84 lines changed
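Every hunk in this commit makes the same one-line change: the guard `if self.training and self.gradient_checkpointing:` (or plain `if self.gradient_checkpointing:` in the Allegro VAE) becomes `if torch.is_grad_enabled() and self.gradient_checkpointing:`. Activation checkpointing is therefore applied whenever autograd is recording, even when the module is in `eval()` mode, and is skipped under `torch.no_grad()`. A minimal sketch of the resulting behaviour, assuming a toy `Block` module that is not part of the commit:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy stand-in for a diffusers resnet/attention block (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True  # normally toggled via enable_gradient_checkpointing()
        self.linear = nn.Linear(64, 64)

    def forward(self, x):
        # New guard: checkpoint whenever autograd is recording, regardless of train/eval mode.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return checkpoint(self.linear, x, use_reentrant=False)
        return self.linear(x)


block = Block().eval()                      # eval() mode, e.g. a frozen model inside a training loop
x = torch.randn(2, 64, requires_grad=True)

out = block(x)                              # grads enabled -> checkpointing runs despite eval()
out.sum().backward()

with torch.no_grad():                       # pure inference -> the checkpointing branch is skipped
    _ = block(x)
```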

examples/community/matryoshka.py

Lines changed: 4 additions & 4 deletions
@@ -868,7 +868,7 @@ def forward(
        blocks = list(zip(self.resnets, self.attentions))

        for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

@@ -1029,7 +1029,7 @@ def forward(

        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

@@ -1191,7 +1191,7 @@ def forward(

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

@@ -1364,7 +1364,7 @@ def forward(

        # Blocks
        for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

examples/research_projects/pixart/controlnet_pixart_alpha.py

Lines changed: 1 addition & 1 deletion
@@ -215,7 +215,7 @@ def forward(

        # 2. Blocks
        for block_index, block in enumerate(self.transformer.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # rc todo: for training and gradient checkpointing
                print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                exit(1)

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 2 additions & 2 deletions
@@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
        sample = self.temp_conv_in(sample)
        sample = sample + residual

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 5 additions & 5 deletions
@@ -420,7 +420,7 @@ def forward(
        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -522,7 +522,7 @@ def forward(
        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -636,7 +636,7 @@ def forward(
        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -773,7 +773,7 @@ def forward(

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -939,7 +939,7 @@ def forward(

        hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 5 additions & 5 deletions
@@ -206,7 +206,7 @@ def forward(
        for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -311,7 +311,7 @@ def forward(
        for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -392,7 +392,7 @@ def forward(
        for i, resnet in enumerate(self.resnets):
            conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def create_forward(*inputs):

@@ -529,7 +529,7 @@ def forward(
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def create_forward(*inputs):

@@ -646,7 +646,7 @@ def forward(
        hidden_states = self.conv_in(hidden_states)

        # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def create_forward(*inputs):

src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def forward(
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

src/diffusers/models/autoencoders/vae.py

Lines changed: 5 additions & 5 deletions
@@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

        sample = self.conv_in(sample)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -291,7 +291,7 @@ def forward(
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -544,7 +544,7 @@ def forward(
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -876,7 +876,7 @@ def __init__(

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward method of the `EncoderTiny` class."""
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

@@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp.
        x = torch.tanh(x / 3) * 3

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):

src/diffusers/models/controlnets/controlnet_flux.py

Lines changed: 2 additions & 2 deletions
@@ -329,7 +329,7 @@ def forward(

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

@@ -363,7 +363,7 @@ def custom_forward(*inputs):

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def forward(
        block_res_samples = ()

        for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 3 additions & 3 deletions
@@ -1466,7 +1466,7 @@ def custom_forward(*inputs):
            h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

            # apply base subblock
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                h_base = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(b_res),

@@ -1489,7 +1489,7 @@ def custom_forward(*inputs):

            # apply ctrl subblock
            if apply_control:
-                if self.training and self.gradient_checkpointing:
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    h_ctrl = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(c_res),

@@ -1898,7 +1898,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
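The hunks above are truncated right where the checkpointed call is made; throughout these files the pattern is the same: a small closure wraps the sub-module so that `torch.utils.checkpoint.checkpoint` can re-run it during the backward pass, and (as the `ckpt_kwargs` lines show) `use_reentrant=False` is passed on torch >= 1.11. A simplified, self-contained sketch of that pattern under the new guard, with an illustrative `TinyDownBlock` that is not the actual diffusers code:

```python
import torch
from torch import nn


def create_custom_forward(module):
    # Wrap the module so checkpoint() receives a plain callable over tensor inputs.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


class TinyDownBlock(nn.Module):
    """Illustrative stand-in for a diffusers down/mid/up block (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True
        self.resnets = nn.ModuleList([nn.Linear(32, 32) for _ in range(2)])

    def forward(self, hidden_states):
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Recompute this sub-module's activations in backward instead of storing them.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, use_reentrant=False
                )
            else:
                hidden_states = resnet(hidden_states)
        return hidden_states


sample = torch.randn(4, 32, requires_grad=True)
loss = TinyDownBlock().eval()(sample).sum()  # eval() + grad enabled still checkpoints
loss.backward()
```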
