Skip to content

Commit 49e7e84

Browse files
committed
Merge branch 'main' into kc/gemma-3-tracing-support
2 parents 0095ac2 + 90c4075 commit 49e7e84

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import inspect
22
from collections import deque
33
from dataclasses import dataclass
4-
from typing import Any, Callable, Dict, List, Set, Union
4+
from typing import Any, Callable, Dict, List, Optional, Set, Union
55

66
from compressed_tensors import has_offloaded_params
77
from compressed_tensors.quantization import find_name_or_class_matches
88
from torch.fx import Graph, GraphModule, Node
9+
from torch.fx.graph import PythonCode
910
from torch.fx.proxy import Argument
1011
from torch.nn import Module
1112
from transformers import PreTrainedModel
@@ -32,16 +33,33 @@ class Subgraph:
3233
graph: Graph
3334
input_names: Set[str]
3435
consumed_names: Set[str]
36+
_code: Optional[PythonCode] = None
3537

36-
def compile_forward(self) -> Callable[[Any], Any]:
38+
def forward(self, *args, **kwargs) -> Dict[str, Any]:
3739
"""
38-
Generate and compile code for executing this subgraph
40+
Execute the operations within the subgraph
3941
40-
:return: function which, when called, executes this subgraph
42+
:param \\*args: argument inputs to subgraph forward function
43+
:param \\**kwargs: keyword inputs to subgraph forward function
44+
:return: keyword outputs of subgraph forward function (non-consumed variables)
4145
"""
42-
code = self.graph.python_code("self")
43-
exec(code.src, code.globals)
44-
return code.globals.get("forward")
46+
if self._code is None:
47+
self._code = self.graph.python_code("self")
48+
exec(self._code.src, self._code.globals)
49+
50+
forward_fn = self._code.globals.get("forward")
51+
52+
try:
53+
outputs = forward_fn(*args, **kwargs)
54+
except Exception as exception:
55+
raise RuntimeError(
56+
"Raised an exception during execution of the following code:\n"
57+
f"```\n{add_line_numbers(self._code.src)}\n```\n"
58+
"This is likely due to a violation of shape assumptions made when "
59+
"tracing"
60+
) from exception
61+
62+
return outputs
4563

4664

4765
def trace_subgraphs(
@@ -376,3 +394,9 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
376394
for name, module in model.named_modules()
377395
if find_name_or_class_matches(name, module, target_names)
378396
)
397+
398+
399+
def add_line_numbers(text: str) -> str:
400+
lines = text.splitlines()
401+
numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)]
402+
return "\n".join(numbered_lines)

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,10 @@ def run_pipeline(
6161
calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
6262
prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"
6363

64-
# compile subgraph forward function
65-
forward_function = subgraph.compile_forward()
66-
6764
# do a preliminary pass to trigger modifier hooks
6865
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
6966
inputs = intermediates.fetch(batch_index, subgraph.input_names)
70-
forward_function(model, **inputs)
67+
subgraph.forward(model, **inputs)
7168

7269
# TODO: replace with a lifecycle event
7370
if callback_modifier:
@@ -78,7 +75,7 @@ def run_pipeline(
7875
with HooksMixin.disable_hooks():
7976
for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
8077
inputs = intermediates.fetch(batch_index, subgraph.input_names)
81-
output = forward_function(model, **inputs)
78+
output = subgraph.forward(model, **inputs)
8279

8380
if subgraph_index < num_subgraphs - 1:
8481
intermediates.update(batch_index, output)

0 commit comments

Comments
 (0)