Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 1affc7d

Browse files
committed
update
1 parent 676d1b6 commit 1affc7d

File tree

1 file changed

+18
-16
lines changed
  • sot/opcode_translator/executor/variables

1 file changed

+18
-16
lines changed

sot/opcode_translator/executor/variables/basic.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,34 @@
3434
if TYPE_CHECKING:
3535
from ..function_graph import FunctionGraph
3636

37-
DTYPE_ABBRS = {
37+
38+
FP_DTYPE_ABBRS = {
3839
paddle.bfloat16: 'bfloat16',
3940
paddle.float64: 'float64',
4041
paddle.float32: 'float32',
4142
paddle.float16: 'float16',
43+
}
44+
45+
CP_DTYPE_ABBRS = {
4246
paddle.complex64: 'complex64',
4347
paddle.complex128: 'complex128',
48+
}
49+
50+
INT_DTYPE_ABBRS = {
4451
paddle.int8: 'int8',
4552
paddle.int16: 'int16',
4653
paddle.int32: 'int32',
4754
paddle.int64: 'int64',
48-
paddle.bool: 'bool',
4955
paddle.uint8: 'uint8',
5056
}
5157

58+
DTYPE_ABBRS = {
59+
**FP_DTYPE_ABBRS,
60+
**CP_DTYPE_ABBRS,
61+
**INT_DTYPE_ABBRS,
62+
paddle.bool: 'bool',
63+
}
64+
5265

5366
class ConstantVariable(VariableBase):
5467
def __init__(
@@ -235,28 +248,17 @@ def is_tensor(self):
235248

236249
def is_complex(self):
237250
dtype = self.meta.dtype
238-
is_cp_dtype = dtype == paddle.complex64 or dtype == paddle.complex128
251+
is_cp_dtype = dtype in CP_DTYPE_ABBRS
239252
return ConstantVariable.wrap_literal(is_cp_dtype)
240253

241254
def is_integer(self):
242255
dtype = self.meta.dtype
243-
is_int_dtype = (
244-
dtype == paddle.int8
245-
or dtype == paddle.uint8
246-
or dtype == paddle.int16
247-
or dtype == paddle.int32
248-
or dtype == paddle.int64
249-
)
256+
is_int_dtype = dtype in INT_DTYPE_ABBRS
250257
return ConstantVariable.wrap_literal(is_int_dtype)
251258

252259
def is_floating_point(self):
253260
dtype = self.meta.dtype
254-
is_fp_dtype = (
255-
dtype == paddle.float32
256-
or dtype == paddle.float64
257-
or dtype == paddle.float16
258-
or dtype == paddle.bfloat16
259-
)
261+
is_fp_dtype = dtype in FP_DTYPE_ABBRS
260262
return ConstantVariable.wrap_literal(is_fp_dtype)
261263

262264
def getattr(self, name: str):

0 commit comments

Comments
 (0)