|
15 | 15 | BreakGraphError, |
16 | 16 | NameGenerator, |
17 | 17 | NotImplementException, |
18 | | - NotImplementException, |
19 | 18 | log_do, |
20 | 19 | paddle_tensor_methods, |
21 | 20 | ) |
|
34 | 33 | if TYPE_CHECKING: |
35 | 34 | from ..function_graph import FunctionGraph |
36 | 35 |
|
37 | | -DTYPE_ABBRS = { |
| 36 | + |
| 37 | +FP_DTYPE_ABBRS = { |
38 | 38 | paddle.bfloat16: 'bfloat16', |
39 | 39 | paddle.float64: 'float64', |
40 | 40 | paddle.float32: 'float32', |
41 | 41 | paddle.float16: 'float16', |
| 42 | +} |
| 43 | + |
| 44 | +CP_DTYPE_ABBRS = { |
42 | 45 | paddle.complex64: 'complex64', |
43 | 46 | paddle.complex128: 'complex128', |
| 47 | +} |
| 48 | + |
| 49 | +INT_DTYPE_ABBRS = { |
44 | 50 | paddle.int8: 'int8', |
45 | 51 | paddle.int16: 'int16', |
46 | 52 | paddle.int32: 'int32', |
47 | 53 | paddle.int64: 'int64', |
48 | | - paddle.bool: 'bool', |
49 | 54 | paddle.uint8: 'uint8', |
50 | 55 | } |
51 | 56 |
|
| 57 | +DTYPE_ABBRS = { |
| 58 | + **FP_DTYPE_ABBRS, |
| 59 | + **CP_DTYPE_ABBRS, |
| 60 | + **INT_DTYPE_ABBRS, |
| 61 | + paddle.bool: 'bool', |
| 62 | +} |
| 63 | + |
52 | 64 |
|
53 | 65 | class ConstantVariable(VariableBase): |
54 | 66 | def __init__( |
@@ -174,7 +186,6 @@ def main_info(self) -> dict[str, Any]: |
174 | 186 | "dtype": DTYPE_ABBRS[self.meta.dtype], |
175 | 187 | "stop_gradient": self.meta.stop_gradient, |
176 | 188 | "var_name": self.var_name, |
177 | | - "var_name": self.var_name, |
178 | 189 | } |
179 | 190 |
|
180 | 191 | def __getitem__(self, key): |
@@ -235,28 +246,17 @@ def is_tensor(self): |
235 | 246 |
|
236 | 247 | def is_complex(self): |
237 | 248 | dtype = self.meta.dtype |
238 | | - is_cp_dtype = dtype == paddle.complex64 or dtype == paddle.complex128 |
| 249 | + is_cp_dtype = dtype in CP_DTYPE_ABBRS |
239 | 250 | return ConstantVariable.wrap_literal(is_cp_dtype) |
240 | 251 |
|
241 | 252 | def is_integer(self): |
242 | 253 | dtype = self.meta.dtype |
243 | | - is_int_dtype = ( |
244 | | - dtype == paddle.int8 |
245 | | - or dtype == paddle.uint8 |
246 | | - or dtype == paddle.int16 |
247 | | - or dtype == paddle.int32 |
248 | | - or dtype == paddle.int64 |
249 | | - ) |
| 254 | + is_int_dtype = dtype in INT_DTYPE_ABBRS |
250 | 255 | return ConstantVariable.wrap_literal(is_int_dtype) |
251 | 256 |
|
252 | 257 | def is_floating_point(self): |
253 | 258 | dtype = self.meta.dtype |
254 | | - is_fp_dtype = ( |
255 | | - dtype == paddle.float32 |
256 | | - or dtype == paddle.float64 |
257 | | - or dtype == paddle.float16 |
258 | | - or dtype == paddle.bfloat16 |
259 | | - ) |
| 259 | + is_fp_dtype = dtype in FP_DTYPE_ABBRS |
260 | 260 | return ConstantVariable.wrap_literal(is_fp_dtype) |
261 | 261 |
|
262 | 262 | def getattr(self, name: str): |
@@ -422,7 +422,36 @@ def get_value(self) -> Any: |
422 | 422 | return self.value |
423 | 423 |
|
424 | 424 | def make_stringify_guard(self) -> StringifyExpression: |
425 | | - raise NotImplementException("We can not stringify numpy variable") |
| 425 | + if isinstance(self.get_value(), np.number): |
| 426 | + assert not isinstance( |
| 427 | + self.tracker, DummyTracker |
| 428 | + ), "Can not make guard from dummy tracker" |
| 429 | + |
| 430 | + frame_value_tracer = self.tracker.trace_value_from_frame() |
| 431 | + log_do( |
| 432 | + 4, |
| 433 | + lambda: print( |
| 434 | + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" |
| 435 | + ), |
| 436 | + ) |
| 437 | + |
| 438 | + def format_dtype(dtype: np.dtype): |
| 439 | + return f"np.{str(dtype)}" |
| 440 | + |
| 441 | + def format_number(number: np.number): |
| 442 | + return f"{format_dtype(number.dtype)}({str(number.item())})" |
| 443 | + |
| 444 | + return StringifyExpression( |
| 445 | + f"{frame_value_tracer.expr} == {format_number(self.get_value())}", |
| 446 | + union_free_vars(frame_value_tracer.free_vars, {"np": np}), |
| 447 | + ) & StringifyExpression( |
| 448 | + f"{frame_value_tracer.expr}.dtype == {format_dtype(self.get_value().dtype)}", |
| 449 | + union_free_vars(frame_value_tracer.free_vars, {"np": np}), |
| 450 | + ) |
| 451 | + else: |
| 452 | + raise NotImplementException( |
| 453 | + "We can not stringify numpy variable when value is np.ndarray" |
| 454 | + ) |
426 | 455 |
|
427 | 456 | @VariableFactory.register_from_value() |
428 | 457 | def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): |
|
0 commit comments