Disable kernels during calibration (and tracing) #1454

Merged
merged 5 commits on May 22, 2025
28 changes: 23 additions & 5 deletions src/llmcompressor/utils/helpers.py
@@ -66,6 +66,7 @@
"eval_context",
"calibration_forward_context",
"patch_attr",
"disable_hf_kernels",
]


@@ -1024,6 +1025,9 @@ def DisableQuantization(module: torch.nn.Module):

@contextlib.contextmanager
def eval_context(module: torch.nn.Module):
"""
Disable pytorch training mode for the given module
"""
restore_value = module.training
try:
module.train(False) # equivalent to eval()
@@ -1033,6 +1037,21 @@ def eval_context(module: torch.nn.Module):
module.train(restore_value)


@contextlib.contextmanager
def disable_hf_kernels(model: PreTrainedModel):
"""
In transformers>=4.50.0, some module forward methods may be
replaced by calls to hf hub kernels. This has the potential
to bypass hooks added by LLM Compressor
"""
if hasattr(model, "config"):
with patch_attr(model.config, "disable_custom_kernels", True):
yield

else:
yield


@contextlib.contextmanager
def calibration_forward_context(model: PreTrainedModel):
"""
@@ -1041,12 +1060,11 @@ def calibration_forward_context(model: PreTrainedModel):
- Remove gradient calculations
- Disable the KV cache
- Disable train mode and enable eval mode
- Disable hf kernels which could bypass hooks
"""
-    with (
-        torch.no_grad(),
-        DisableKVCache(model),
-        eval_context(model),
-    ):
+    with torch.no_grad(), DisableKVCache(model), eval_context(
+        model
+    ), disable_hf_kernels(model):
yield


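For readers following along, here is a minimal usage sketch of the behavior this change enables. It is not part of the PR: the model choice, the forward hook, and the assertions are illustrative, and it assumes only what the diff above shows, namely that `calibration_forward_context` (which now also enters `disable_hf_kernels`) is exported from `llmcompressor.utils.helpers`.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.utils.helpers import calibration_forward_context

model_id = "facebook/opt-125m"  # illustrative model, not from the PR
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A calibration-style forward hook: it has to see every forward call.
# If transformers swapped this module's forward for a hub kernel, hooks
# like this one could be bypassed, which is what the PR guards against.
seen = []
handle = model.model.decoder.layers[0].register_forward_hook(
    lambda module, args, output: seen.append(1)
)

inputs = tokenizer("calibration sample", return_tensors="pt")
with calibration_forward_context(model):
    # Inside the context: no_grad, KV cache off, eval mode, and (new in
    # this PR) model.config.disable_custom_kernels patched to True.
    assert model.config.disable_custom_kernels is True
    assert not model.training
    model(**inputs)

handle.remove()
assert seen, "forward hook should have fired during calibration"
```

Outside the context manager, `patch_attr` is expected to restore the config attribute to its previous state, so normal inference after calibration can still pick up hub kernels where available.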
30 changes: 12 additions & 18 deletions tests/examples/test_quantization_2of4_sparse_w4a16.py
@@ -36,9 +36,9 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
readme = ReadMe(readme_path)

command = readme.get_code_block_content(position=2, lang="shell")
assert command.startswith("python"), (
"Expected shell command to start with 'python'"
)
assert command.startswith(
"python"
), "Expected shell command to start with 'python'"

command = shlex.split(command)
result = copy_and_run_command(tmp_path, example_dir, command)
@@ -62,18 +62,16 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
}

for stage, stage_info in stages.items():
-            stage_path = (
-                tmp_path / example_dir / output_dir / stage_info["path"]
-            )
+            stage_path = tmp_path / example_dir / output_dir / stage_info["path"]
recipe_path = stage_path / "recipe.yaml"
config_path = stage_path / "config.json"

-            assert recipe_path.exists(), (
-                f"Missing recipe file in {stage}: {recipe_path}"
-            )
-            assert config_path.exists(), (
-                f"Missing config file in {stage}: {config_path}"
-            )
+            assert (
+                recipe_path.exists()
+            ), f"Missing recipe file in {stage}: {recipe_path}"
+            assert (
+                config_path.exists()
+            ), f"Missing config file in {stage}: {config_path}"

config = AutoConfig.from_pretrained(stage_path)
assert config is not None, f"Failed to load config in {stage}"
@@ -82,13 +80,9 @@ def test_doc_example_command(self, example_dir: str, tmp_path: Path):
if stage == "quantization":
actual_format = quant_config.get("format")
else:
-                actual_format = quant_config.get(
-                    "sparsity_config", {}
-                ).get("format")
+                actual_format = quant_config.get("sparsity_config", {}).get("format")

-            assert actual_format, (
-                f"Missing expected format field in {stage} config"
-            )
+            assert actual_format, f"Missing expected format field in {stage} config"
assert actual_format == stage_info["format"], (
f"Unexpected format in {stage}: got '{actual_format}', "
f"expected '{stage_info['format']}'"