Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 0d1e599

Browse files
committed
update
1 parent 676d1b6 commit 0d1e599

File tree

3 files changed

+49
-22
lines changed

3 files changed

+49
-22
lines changed

sot/opcode_translator/executor/dispatcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from functools import cached_property, reduce
66
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, TypeVar
77

8-
import paddle
9-
108
from ...utils import InnerError
119

1210
if TYPE_CHECKING:

sot/opcode_translator/executor/variable_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
lambda var: var.bool(),
100100
)
101101

102-
# VariableBase
102+
# TensorVariable
103103
Dispatcher.register(
104104
paddle.is_tensor,
105105
("TensorVariable",),

sot/opcode_translator/executor/variables/basic.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
BreakGraphError,
1616
NameGenerator,
1717
NotImplementException,
18-
NotImplementException,
1918
log_do,
2019
paddle_tensor_methods,
2120
)
@@ -34,21 +33,34 @@
3433
if TYPE_CHECKING:
3534
from ..function_graph import FunctionGraph
3635

37-
DTYPE_ABBRS = {
36+
37+
FP_DTYPE_ABBRS = {
3838
paddle.bfloat16: 'bfloat16',
3939
paddle.float64: 'float64',
4040
paddle.float32: 'float32',
4141
paddle.float16: 'float16',
42+
}
43+
44+
CP_DTYPE_ABBRS = {
4245
paddle.complex64: 'complex64',
4346
paddle.complex128: 'complex128',
47+
}
48+
49+
INT_DTYPE_ABBRS = {
4450
paddle.int8: 'int8',
4551
paddle.int16: 'int16',
4652
paddle.int32: 'int32',
4753
paddle.int64: 'int64',
48-
paddle.bool: 'bool',
4954
paddle.uint8: 'uint8',
5055
}
5156

57+
DTYPE_ABBRS = {
58+
**FP_DTYPE_ABBRS,
59+
**CP_DTYPE_ABBRS,
60+
**INT_DTYPE_ABBRS,
61+
paddle.bool: 'bool',
62+
}
63+
5264

5365
class ConstantVariable(VariableBase):
5466
def __init__(
@@ -174,7 +186,6 @@ def main_info(self) -> dict[str, Any]:
174186
"dtype": DTYPE_ABBRS[self.meta.dtype],
175187
"stop_gradient": self.meta.stop_gradient,
176188
"var_name": self.var_name,
177-
"var_name": self.var_name,
178189
}
179190

180191
def __getitem__(self, key):
@@ -235,28 +246,17 @@ def is_tensor(self):
235246

236247
def is_complex(self):
237248
dtype = self.meta.dtype
238-
is_cp_dtype = dtype == paddle.complex64 or dtype == paddle.complex128
249+
is_cp_dtype = dtype in CP_DTYPE_ABBRS
239250
return ConstantVariable.wrap_literal(is_cp_dtype)
240251

241252
def is_integer(self):
242253
dtype = self.meta.dtype
243-
is_int_dtype = (
244-
dtype == paddle.int8
245-
or dtype == paddle.uint8
246-
or dtype == paddle.int16
247-
or dtype == paddle.int32
248-
or dtype == paddle.int64
249-
)
254+
is_int_dtype = dtype in INT_DTYPE_ABBRS
250255
return ConstantVariable.wrap_literal(is_int_dtype)
251256

252257
def is_floating_point(self):
253258
dtype = self.meta.dtype
254-
is_fp_dtype = (
255-
dtype == paddle.float32
256-
or dtype == paddle.float64
257-
or dtype == paddle.float16
258-
or dtype == paddle.bfloat16
259-
)
259+
is_fp_dtype = dtype in FP_DTYPE_ABBRS
260260
return ConstantVariable.wrap_literal(is_fp_dtype)
261261

262262
def getattr(self, name: str):
@@ -422,7 +422,36 @@ def get_value(self) -> Any:
422422
return self.value
423423

424424
def make_stringify_guard(self) -> StringifyExpression:
425-
raise NotImplementException("We can not stringify numpy variable")
425+
if isinstance(self.get_value(), np.number):
426+
assert not isinstance(
427+
self.tracker, DummyTracker
428+
), "Can not make guard from dummy tracker"
429+
430+
frame_value_tracer = self.tracker.trace_value_from_frame()
431+
log_do(
432+
4,
433+
lambda: print(
434+
f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}"
435+
),
436+
)
437+
438+
def format_dtype(dtype: np.dtype):
439+
return f"np.{str(dtype)}"
440+
441+
def format_number(number: np.number):
442+
return f"{format_dtype(number.dtype)}({str(number.item())})"
443+
444+
return StringifyExpression(
445+
f"{frame_value_tracer.expr} == {format_number(self.get_value())}",
446+
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
447+
) & StringifyExpression(
448+
f"{frame_value_tracer.expr}.dtype == {format_dtype(self.get_value().dtype)}",
449+
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
450+
)
451+
else:
452+
raise NotImplementException(
453+
"We can not stringify numpy variable when value is np.ndarray"
454+
)
426455

427456
@VariableFactory.register_from_value()
428457
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):

0 commit comments

Comments
 (0)