Skip to content

Commit 4191e06

Browse files
Trace api (#70752)
* trace_api
1 parent 01e22b4 commit 4191e06

File tree

5 files changed

+229
-0
lines changed

5 files changed

+229
-0
lines changed

python/paddle/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,3 +1223,13 @@
12231223
'pi',
12241224
'e',
12251225
]
1226+
1227+
import os
1228+
1229+
FLAGS_trace_api = os.environ.get("FLAGS_trace_api", None)
1230+
if FLAGS_trace_api is not None and FLAGS_trace_api != "":
1231+
from .api_tracer import start_api_tracer
1232+
1233+
api_path = FLAGS_trace_api.split(",")[0]
1234+
save_config_path = FLAGS_trace_api.split(",")[1]
1235+
start_api_tracer(api_path, save_config_path)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .api_tracer import start_api_tracer
16+
17+
__all__ = [
18+
'api_tracer',
19+
'start_api_tracer',
20+
]
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
17+
import numpy as np
18+
import yaml
19+
20+
21+
class HookAPIMap:
22+
pass
23+
24+
25+
class ConfigDump:
26+
def __init__(self):
27+
pass
28+
29+
def open_file(self, path):
30+
self.file = open(path, "a+")
31+
32+
def dump_config(self, api, input_args, input_kwargs, outputs):
33+
result = api + "("
34+
for value in input_args:
35+
tmp = self.dump_item_str(api, value)
36+
if tmp == "":
37+
return
38+
result = result + tmp + ", "
39+
for key, value in input_kwargs.items():
40+
tmp = self.dump_item_str(api, value)
41+
if tmp == "":
42+
return
43+
result = result + key + "=" + tmp + ", "
44+
45+
result = result + ")"
46+
# self.file.write(") -> ")
47+
# if isinstance(outputs, (list, tuple)):
48+
# for output in outputs:
49+
# self.file.write(self.dump_item_str(api, output) + ", ")
50+
# else:
51+
# self.file.write(self.dump_item_str(api, outputs) + ", ")
52+
53+
self.file.write(result)
54+
self.file.write("\n")
55+
self.file.flush()
56+
57+
def dump_item_str(self, api, item):
58+
import paddle
59+
60+
type_mapping = {
61+
np.int16: int,
62+
np.int32: int,
63+
np.int64: int,
64+
np.float16: float,
65+
np.float32: float,
66+
np.float64: float,
67+
np.integer: int,
68+
np.floating: float,
69+
np.bool_: bool,
70+
np.complexfloating: complex,
71+
np.str_: str,
72+
np.bytes_: bytes,
73+
# np.unicode_: str,
74+
}
75+
for numpy_type, builtin_type in type_mapping.items():
76+
if isinstance(item, numpy_type):
77+
item = builtin_type(item)
78+
break
79+
80+
if isinstance(item, paddle.Tensor):
81+
return (
82+
"Tensor(" + str(item.shape) + ',"' + str(item.dtype)[7:] + '")'
83+
)
84+
elif isinstance(item, paddle.base.core.DataType):
85+
return "Dtype(" + str(item)[7:] + ")"
86+
elif isinstance(item, paddle.base.core.VarDesc.VarType):
87+
return "VarType(" + str(item)[7:] + ")"
88+
elif isinstance(item, list):
89+
result = "list["
90+
for sub_item in item:
91+
tmp = self.dump_item_str(api, sub_item)
92+
if tmp == "":
93+
return ""
94+
result = result + tmp + ","
95+
result = result + "]"
96+
return result
97+
elif isinstance(item, tuple):
98+
result = "tuple("
99+
for sub_item in item:
100+
tmp = self.dump_item_str(api, sub_item)
101+
if tmp == "":
102+
return ""
103+
result = result + tmp + ","
104+
result = result + ")"
105+
return result
106+
elif isinstance(item, slice):
107+
return (
108+
"slice("
109+
+ str(item.start)
110+
+ ","
111+
+ str(item.stop)
112+
+ ","
113+
+ str(item.step)
114+
+ ")"
115+
)
116+
elif isinstance(item, complex):
117+
return (
118+
"complex("
119+
+ self.dump_item_str(api, item.real)
120+
+ ","
121+
+ self.dump_item_str(api, item.imag)
122+
+ ")"
123+
)
124+
elif item is None:
125+
return "None"
126+
elif isinstance(
127+
item, (paddle.base.Variable, paddle.base.libpaddle.pir.Value)
128+
):
129+
return ""
130+
elif item == math.inf:
131+
return "math.inf"
132+
elif item == -math.inf:
133+
return "-math.inf"
134+
elif item == math.nan:
135+
return "math.nan"
136+
elif item == -math.nan:
137+
return "-math.nan"
138+
elif isinstance(item, (bool, int, float)):
139+
return str(item)
140+
elif isinstance(item, str):
141+
return '"' + item + '"'
142+
elif isinstance(item, type):
143+
return (
144+
"type("
145+
+ str(item)[str(item).index("'") + 1 : str(item).rindex("'")]
146+
+ ")"
147+
)
148+
else:
149+
print("[api_tracer error] : dump_item_str ", api, ", item = ", item)
150+
return ""
151+
152+
153+
config_dump = ConfigDump()
154+
155+
156+
class APITemplate:
157+
def __init__(self, api_name):
158+
self.api_name = api_name
159+
160+
def __call__(self, *args, **kwargs):
161+
output = getattr(HookAPIMap, self.api_name)(*args, **kwargs)
162+
try:
163+
config_dump.dump_config(self.api_name, args, kwargs, output)
164+
except Exception as err:
165+
print(
166+
"[api_tracer error] : config_dump.dump_config ",
167+
self.api_name,
168+
str(err),
169+
)
170+
return output
171+
172+
173+
def wrapped_api(api_name):
174+
def api_template(*args, **kwargs):
175+
return APITemplate(api_name)(*args, **kwargs)
176+
177+
return api_template
178+
179+
180+
def start_api_tracer(api_path, save_config_path):
181+
import paddle
182+
183+
print(paddle.__version__)
184+
with open(api_path, "r") as f:
185+
apis = yaml.safe_load(f)
186+
sample_apis = apis.get("apis")
187+
f.close()
188+
189+
for api in sample_apis:
190+
parent_package, method_name = api.rsplit(".", maxsplit=1)
191+
try:
192+
setattr(HookAPIMap, api, getattr(eval(parent_package), method_name))
193+
setattr(eval(parent_package), method_name, wrapped_api(api))
194+
except Exception as err:
195+
print("[api_tracer error] : start_api_tracer ", api, str(err))
196+
197+
config_dump.open_file(save_config_path)

python/setup.py.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ packages=['paddle',
821821
'paddle.decomposition',
822822
'paddle._typing',
823823
'paddle._typing.libs',
824+
'paddle.api_tracer',
824825
]
825826

826827
if '@WITH_PIP_TENSORRT@' =='ON':

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2080,6 +2080,7 @@ def get_setup_parameters():
20802080
'paddle.decomposition',
20812081
'paddle._typing',
20822082
'paddle._typing.libs',
2083+
'paddle.api_tracer',
20832084
]
20842085

20852086
if env_dict.get("WITH_PIP_TENSORRT") == 'ON':

0 commit comments

Comments
 (0)