1
1
import inspect
2
2
from collections import deque
3
3
from dataclasses import dataclass
4
- from typing import Any , Callable , Dict , List , Set , Union
4
+ from typing import Any , Callable , Dict , List , Optional , Set , Union
5
5
6
6
from compressed_tensors import has_offloaded_params
7
7
from compressed_tensors .quantization import find_name_or_class_matches
8
8
from torch .fx import Graph , GraphModule , Node
9
+ from torch .fx .graph import PythonCode
9
10
from torch .fx .proxy import Argument
10
11
from torch .nn import Module
11
12
from transformers import PreTrainedModel
@@ -32,16 +33,33 @@ class Subgraph:
32
33
graph : Graph
33
34
input_names : Set [str ]
34
35
consumed_names : Set [str ]
36
+ _code : Optional [PythonCode ] = None
35
37
36
- def compile_forward (self ) -> Callable [[ Any ] , Any ]:
38
+ def forward (self , * args , ** kwargs ) -> Dict [ str , Any ]:
37
39
"""
38
- Generate and compile code for executing this subgraph
40
+ Execute the operations within the subgraph
39
41
40
- :return: function which, when called, executes this subgraph
42
+ :param \\ *args: argument inputs to subgraph forward function
43
+ :param \\ **kwargs: keyword inputs to subgraph forward function
44
+ :return keyword outputs of subgraph forward function (non-consumed variables):
41
45
"""
42
- code = self .graph .python_code ("self" )
43
- exec (code .src , code .globals )
44
- return code .globals .get ("forward" )
46
+ if self ._code is None :
47
+ self ._code = self .graph .python_code ("self" )
48
+ exec (self ._code .src , self ._code .globals )
49
+
50
+ forward_fn = self ._code .globals .get ("forward" )
51
+
52
+ try :
53
+ outputs = forward_fn (* args , ** kwargs )
54
+ except Exception as exception :
55
+ raise RuntimeError (
56
+ "Raised an exception during execution of the following code:\n "
57
+ f"```\n { add_line_numbers (self ._code .src )} \n ```\n "
58
+ "This is likely due to a violation of shape assumptions made when "
59
+ "tracing"
60
+ ) from exception
61
+
62
+ return outputs
45
63
46
64
47
65
def trace_subgraphs (
@@ -376,3 +394,9 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
376
394
for name , module in model .named_modules ()
377
395
if find_name_or_class_matches (name , module , target_names )
378
396
)
397
+
398
+
399
+ def add_line_numbers (text : str ) -> str :
400
+ lines = text .splitlines ()
401
+ numbered_lines = [f"{ i + 1 } { line } " for i , line in enumerate (lines )]
402
+ return "\n " .join (numbered_lines )
0 commit comments