This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Description
一、背景
Symbolic Translate 对于 Tensor 中的method调用,比如 a + b,我们会调用对应的 tensor variable 的魔法函数进行组网。但是 tensor 中存在一些方法调用行为不通用,比如 tensor.size 返回所有的 elements 个数(ConstVariable),并且在存在 -1 时需要进行子图打断。这类情况我们需要进行特殊处理。
二、 任务目标
任务目标:先验的考虑所有的 Tensor 的属性和方法,考察所有的属性是否『动静统一』:
-
如果 Tensor 和 Variable 有同样的接口,并且返回的值都是 Tensor,那么可以动静统一。
-
如果 Tensor 和 Variable 接口存在diff,或者是返回值不是 Tensor,而是int,比如 size, shape 等,我们考虑子图打断。
-
最后可以考虑动态图 fallback, raise NotImplementException 异常。
三、 Hint
- 查看所有的属性方法:
import paddle
a = paddle.to_tensor([1])
dir (a.__class__) # 查看所有的属性和方法
print (a.__class__.size) # 查看属性的类型,可以看到是 method / unbound fuction / property
- TensorVariable 中的代码
def __getattr__(self, name: str):
if name in paddle_tensor_methods:
from .callable import TensorMethodVariable
return TensorMethodVariable(
self, name, self.graph, tracker=GetAttrTracker(self, name)
)
elif name in ["shape", "dtype", "stop_gradient"]:
return VariableFactory.from_value(
getattr(self.meta, name),
self.graph,
tracker=GetAttrTracker(self, name),
)
elif name in ["T", "ndim", "size"]:
return getattr(self, name)
else:
raise InnerError(f"Unknown Tensor attribute: {name}")
这里的 'T', 'ndim', 'size' 就是对应的处理。然后需要在 TensorVariable下实现一个。比如 shape :
@property
def shape(self):
return ConstantVariable.wrap_literal(len(self.meta.shape))