Skip to content

Commit cf9a33d

Browse files
ydshiehelusenji
authored andcommitted
Reduce memory leak in _create_and_check_torchscript (huggingface#16691)
Co-authored-by: ydshieh <[email protected]>
1 parent 245c5e0 commit cf9a33d

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_modeling_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,13 @@ def test_torchscript_output_hidden_state(self):
598598
config.output_hidden_states = True
599599
self._create_and_check_torchscript(config, inputs_dict)
600600

601+
# This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry`
602+
def clear_torch_jit_class_registry(self):
603+
604+
torch._C._jit_clear_class_registry()
605+
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
606+
torch.jit._state._clear_class_state()
607+
601608
def _create_and_check_torchscript(self, config, inputs_dict):
602609
if not self.test_torchscript:
603610
return
@@ -679,6 +686,10 @@ def _create_and_check_torchscript(self, config, inputs_dict):
679686

680687
self.assertTrue(models_equal)
681688

689+
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
690+
# (Even with this call, there are still memory leak by ~0.04MB)
691+
self.clear_torch_jit_class_registry()
692+
682693
def test_torch_fx(self):
683694
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
684695
self._create_and_check_torch_fx_tracing(config, inputs_dict)

0 commit comments

Comments
 (0)