This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Merged
33 changes: 21 additions & 12 deletions sot/infer_meta.py
@@ -19,14 +19,28 @@ def value_fn(self, *args, **kwargs):


class MetaInfo:
def __init__(self, shape, dtype, stop_gradient):
def __init__(
self, shape, dtype, stop_gradient, name, persistable, type, place
):
self.name = name
self.persistable = persistable
self.type = type
self.place = place
self.shape = shape
self.dtype = dtype
self.stop_gradient = stop_gradient

@staticmethod
def from_tensor(tensor):
return MetaInfo(tensor.shape, tensor.dtype, tensor.stop_gradient)
return MetaInfo(
list(tensor.shape),
tensor.dtype,
tensor.stop_gradient,
tensor.name,
tensor.persistable,
tensor.type,
tensor.place,
)
Member:

As it stands, this will affect the Tensor guard.

@2742195759 should this information also go into the meta?

Collaborator:

See the CacheKey used by AST dynamic-to-static conversion:
[image: AST dynamic-to-static CacheKey]
If MetaInfo carries these fields, then guard checks based on MetaInfo are aligned with dynamic-to-static conversion.
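
A minimal sketch of how the extended MetaInfo is meant to interact with the new guard (this assumes MetaInfo.from_tensor also accepts eager tensors, which the guard expression in basic.py below relies on):

import paddle

from sot.infer_meta import MetaInfo

x = paddle.rand([2, 3])
meta = MetaInfo.from_tensor(x)  # now also records name/persistable/type/place

# The guard generated by make_stringify_guard evaluates an expression
# equivalent to the line below; guard_str() still compares only shape,
# dtype and stop_gradient, so the extra fields do not tighten the guard.
assert MetaInfo.from_tensor(x).guard_str() == meta.guard_str()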


def is_dynamic_shape(self):
"""
@@ -40,6 +54,9 @@ def to_input_spec(self):
self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient
)

def guard_str(self):
return f"({self.shape}, {self.dtype}, {self.stop_gradient})"

def __repr__(self):
return meta_str(self.shape, self.dtype, self.stop_gradient)

@@ -128,11 +145,7 @@ def variable_to_meta_info(args):
return map_if(
args,
pred=lambda x: isinstance(x, paddle.static.Variable),
true_fn=lambda x: MetaInfo(
list(x.shape),
x.dtype,
x.stop_gradient,
),
true_fn=lambda x: MetaInfo.from_tensor(x),
false_fn=lambda x: x,
)

@@ -153,11 +166,7 @@ def infer_meta_for_layer(layer, *args, **kwargs):
args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs)
concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0]
out = concrete_program.outputs[0]
out = MetaInfo(
list(out.shape),
out.dtype,
out.stop_gradient,
)
out = MetaInfo.from_tensor(out)
layer.forward.rollback()
return out

33 changes: 33 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
@@ -4,6 +4,8 @@
from functools import partial
from typing import TYPE_CHECKING

import paddle

from ...utils import BreakGraphError, NotImplementException
from ...utils.magic_methods import (
BINARY_OPS,
@@ -97,6 +99,37 @@
lambda var: var.bool(),
)

# TensorVariable
Dispatcher.register(
paddle.is_tensor,
("TensorVariable",),
{},
lambda var: var.is_tensor(),
)
Dispatcher.register(
paddle.is_complex,
("TensorVariable",),
{},
lambda var: var.is_complex(),
)
Dispatcher.register(
paddle.is_integer,
("TensorVariable",),
{},
lambda var: var.is_integer(),
)
Dispatcher.register(
paddle.is_floating_point,
("TensorVariable",),
{},
lambda var: var.is_floating_point(),
)
Dispatcher.register(
paddle.rank,
("TensorVariable",),
{},
lambda var: var.ndim,
)

# VariableBase
Dispatcher.register(
96 changes: 78 additions & 18 deletions sot/opcode_translator/executor/variables/basic.py
@@ -33,21 +33,34 @@
if TYPE_CHECKING:
from ..function_graph import FunctionGraph

DTYPE_ABBRS = {

FP_DTYPE_ABBRS = {
paddle.bfloat16: 'bfloat16',
paddle.float64: 'float64',
paddle.float32: 'float32',
paddle.float16: 'float16',
}

CP_DTYPE_ABBRS = {
paddle.complex64: 'complex64',
paddle.complex128: 'complex128',
}

INT_DTYPE_ABBRS = {
paddle.int8: 'int8',
paddle.int16: 'int16',
paddle.int32: 'int32',
paddle.int64: 'int64',
paddle.bool: 'bool',
paddle.uint8: 'uint8',
}

DTYPE_ABBRS = {
**FP_DTYPE_ABBRS,
**CP_DTYPE_ABBRS,
**INT_DTYPE_ABBRS,
paddle.bool: 'bool',
}


class ConstantVariable(VariableBase):
def __init__(
@@ -95,6 +108,14 @@ def wrap_literal(value: Any) -> ConstantVariable:
return ConstantVariable(value, ConstTracker(value))


IMPLEMENTED_TENSOR_PROPERTIES = set()


def tensor_property(func):
IMPLEMENTED_TENSOR_PROPERTIES.add(func.__name__)
return property(func)


class TensorVariable(VariableBase):
var_name_generator = NameGenerator("var_")

@@ -151,7 +172,7 @@ def make_stringify_guard(self) -> StringifyExpression:
),
)
return StringifyExpression(
f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.meta}'",
f"MetaInfo.from_tensor({frame_value_tracer.expr}).guard_str() == '{self.meta.guard_str()}'",
union_free_vars(
{"MetaInfo": MetaInfo},
frame_value_tracer.free_vars,
@@ -186,7 +207,7 @@ def __setitem__(self, key, value):
value,
)

@property
@tensor_property
def T(self):
perm = list(range(len(self.meta.shape) - 1, -1, -1))
perm_var = VariableFactory.from_value(
@@ -195,40 +216,79 @@ def T(self):
out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
return out

@property
@tensor_property
def ndim(self):
return ConstantVariable.wrap_literal(len(self.meta.shape))

@property
def shape(self):
@tensor_property
def size(self):
# TODO: maybe break graph.
if self.meta.is_dynamic_shape():
raise BreakGraphError(
f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
elements = reduce(operator.mul, self.meta.shape, 1)
return ConstantVariable.wrap_literal(elements)

@tensor_property
def shape(self):
if self.meta.is_dynamic_shape():
raise BreakGraphError(
f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
self.graph.add_global_guarded_variable(self)
return VariableFactory.from_value(
self.meta.shape, self.graph, tracker=ConstTracker(self.meta.shape)
)

@property
def size(self):
# TODO: maybe break graph.
if self.meta.is_dynamic_shape():
raise BreakGraphError(
f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
)
elements = reduce(operator.mul, self.meta.shape, 1)
return ConstantVariable.wrap_literal(elements)
def is_tensor(self):
return ConstantVariable.wrap_literal(True)

def is_complex(self):
dtype = self.meta.dtype
is_cp_dtype = dtype in CP_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_cp_dtype)

def is_integer(self):
dtype = self.meta.dtype
is_int_dtype = dtype in INT_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_int_dtype)

def is_floating_point(self):
dtype = self.meta.dtype
is_fp_dtype = dtype in FP_DTYPE_ABBRS
return ConstantVariable.wrap_literal(is_fp_dtype)

def getattr(self, name: str):
if name in ["shape", "dtype", "stop_gradient"]:
method_name_to_builtin_fn = {
"dim": paddle.rank,
"ndimension": paddle.rank,
"is_tensor": paddle.is_tensor,
"is_complex": paddle.is_complex,
"is_integer": paddle.is_integer,
"is_floating_point": paddle.is_floating_point,
}
if name in ["dtype", "type", "name", "persistable", "stop_gradient"]:
if name == "name" and self.meta.name.startswith(
"infer_meta_variable_tmp"
):
raise BreakGraphError(f"{self.meta.name} is a middle tensor.")
return VariableFactory.from_value(
getattr(self.meta, name),
self.graph,
tracker=GetAttrTracker(self, name),
)
elif name in ["T", "ndim", "size"]:
elif name in IMPLEMENTED_TENSOR_PROPERTIES:
return getattr(self, name)
elif name in method_name_to_builtin_fn:
# TODO: backward, gradient
from .callable import BuiltinVariable

builtin_fn = method_name_to_builtin_fn[name]

return BuiltinVariable(
builtin_fn, self.graph, DanglingTracker()
).bind(self, name)
elif name in paddle_tensor_methods:
from .callable import TensorFunctionVariable

24 changes: 22 additions & 2 deletions tests/test_18_tensor_method.py
@@ -26,7 +26,21 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor):


def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
return a @ b.T + len(a.shape) + b.size + a.ndim
return (
a.name,
Member:

A question just occurred to me: a.name should be fine here, but (a + b).name is computed via infer meta, so that case should be wrong.

Accessing .name on an intermediate node should break the graph.

Following this line of thinking, check whether the other properties have similar issues.
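
For illustration, a hedged sketch of the behaviour this thread converges on; symbolic_translate is assumed to be the public entry point here, and the check itself is the one added to TensorVariable.getattr in basic.py above:

import paddle

from sot import symbolic_translate  # assumed entry point; adjust to the repo's API

def middle_name(a: paddle.Tensor, b: paddle.Tensor):
    # (a + b) exists only as an infer_meta temporary inside the traced graph,
    # so its meta name starts with "infer_meta_variable_tmp" and reading .name
    # raises BreakGraphError, which falls back to eager execution.
    return (a + b).name

out = symbolic_translate(middle_name)(paddle.rand([2]), paddle.rand([2]))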

Member Author:

When getting the name of an intermediate variable, if we don't break the graph, the two sides generate names independently, like this:

- infer_meta_variable_tmp_0
+ eager_tmp_2

If we do break the graph, the counter keeps incrementing, like this:

AssertionError: 'eager_tmp_2' != 'eager_tmp_3'
- eager_tmp_2
?           ^
+ eager_tmp_3
?           ^

So this test case probably can't be added here.

Also, can I just check whether the name starts with infer_meta_variable_tmp to decide whether it is an intermediate variable?

Member:

> Also, can I just check whether the name starts with infer_meta_variable_tmp to decide whether it is an intermediate variable?

self.value == None means it is an intermediate variable, but don't check that inline; wrap it in a function, e.g. is_leaf (for the not-None case), or some other name.
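
A minimal sketch of the suggested helper; the name is_leaf and its exact placement are assumptions here, only the self.value criterion comes from this comment:

def is_leaf(self) -> bool:  # hypothetical method on TensorVariable
    # Intermediate tensors produced via infer_meta have no concrete value bound
    # to the variable, so self.value is None for them.
    return self.value is not None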

str(a.place),
a.persistable,
a.dtype,
a.type,
a.is_tensor(),
a.clear_gradient(),
a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(),
)


def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
c = a + b
return c.name


class TestTensorMethod(TestCaseBase):
@@ -47,9 +61,15 @@ def test_tensor_method_passed_by_user(self):
self.assert_results(tensor_method_passed_by_user, x, y.add)

def test_tensor_method_property(self):
x = paddle.rand([42, 24], dtype='float64')
y = paddle.rand([42, 24], dtype='float32')
self.assert_results(tensor_method_property, x, y)

@unittest.skip("TODO: dynamic tensor name is different")
Member:

The intermediate-variable issue can be handled in the next PR~

def test_middle_tensor_name(self):
x = paddle.rand([42, 24])
y = paddle.rand([42, 24])
self.assert_results(tensor_method_property, x, y)
self.assert_results(middle_tensor_name, x, y)


if __name__ == "__main__":