|
34 | 34 | if TYPE_CHECKING: |
35 | 35 | from ..function_graph import FunctionGraph |
36 | 36 |
|
37 | | -DTYPE_ABBRS = { |
| 37 | + |
| 38 | +FP_DTYPE_ABBRS = { |
38 | 39 | paddle.bfloat16: 'bfloat16', |
39 | 40 | paddle.float64: 'float64', |
40 | 41 | paddle.float32: 'float32', |
41 | 42 | paddle.float16: 'float16', |
| 43 | +} |
| 44 | + |
| 45 | +CP_DTYPE_ABBRS = { |
42 | 46 | paddle.complex64: 'complex64', |
43 | 47 | paddle.complex128: 'complex128', |
| 48 | +} |
| 49 | + |
| 50 | +INT_DTYPE_ABBRS = { |
44 | 51 | paddle.int8: 'int8', |
45 | 52 | paddle.int16: 'int16', |
46 | 53 | paddle.int32: 'int32', |
47 | 54 | paddle.int64: 'int64', |
48 | | - paddle.bool: 'bool', |
49 | 55 | paddle.uint8: 'uint8', |
50 | 56 | } |
51 | 57 |
|
| 58 | +DTYPE_ABBRS = { |
| 59 | + **FP_DTYPE_ABBRS, |
| 60 | + **CP_DTYPE_ABBRS, |
| 61 | + **INT_DTYPE_ABBRS, |
| 62 | + paddle.bool: 'bool', |
| 63 | +} |
| 64 | + |
52 | 65 |
|
53 | 66 | class ConstantVariable(VariableBase): |
54 | 67 | def __init__( |
@@ -235,28 +248,17 @@ def is_tensor(self): |
235 | 248 |
|
236 | 249 | def is_complex(self): |
237 | 250 | dtype = self.meta.dtype |
238 | | - is_cp_dtype = dtype == paddle.complex64 or dtype == paddle.complex128 |
| 251 | + is_cp_dtype = dtype in CP_DTYPE_ABBRS |
239 | 252 | return ConstantVariable.wrap_literal(is_cp_dtype) |
240 | 253 |
|
241 | 254 | def is_integer(self): |
242 | 255 | dtype = self.meta.dtype |
243 | | - is_int_dtype = ( |
244 | | - dtype == paddle.int8 |
245 | | - or dtype == paddle.uint8 |
246 | | - or dtype == paddle.int16 |
247 | | - or dtype == paddle.int32 |
248 | | - or dtype == paddle.int64 |
249 | | - ) |
| 256 | + is_int_dtype = dtype in INT_DTYPE_ABBRS |
250 | 257 | return ConstantVariable.wrap_literal(is_int_dtype) |
251 | 258 |
|
252 | 259 | def is_floating_point(self): |
253 | 260 | dtype = self.meta.dtype |
254 | | - is_fp_dtype = ( |
255 | | - dtype == paddle.float32 |
256 | | - or dtype == paddle.float64 |
257 | | - or dtype == paddle.float16 |
258 | | - or dtype == paddle.bfloat16 |
259 | | - ) |
| 261 | + is_fp_dtype = dtype in FP_DTYPE_ABBRS |
260 | 262 | return ConstantVariable.wrap_literal(is_fp_dtype) |
261 | 263 |
|
262 | 264 | def getattr(self, name: str): |
|
0 commit comments