Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

支持对 TensorVariable 中的 property 属性进行处理 #151

@2742195759

Description

@2742195759

一、背景

Symbolic Translate 对于 Tensor 中的method调用,比如 a + b,我们会调用对应的 tensor variable 的魔法函数进行组网。但是 tensor 中存在一些方法调用行为不通用,比如 tensor.size 返回所有的 elements 个数(ConstVariable),并且在存在 -1 时需要进行子图打断。这类情况我们需要进行特殊处理。

二、 任务目标

任务目标:先验的考虑所有的 Tensor 的属性和方法,考察所有的属性是否『动静统一』:

  • 如果 Tensor 和 Variable 有同样的接口,并且返回的值都是 Tensor,那么可以动静统一。

  • 如果 Tensor 和 Variable 接口存在diff,或者是返回值不是 Tensor,而是int,比如 size, shape 等,我们考虑子图打断。

  • 最后可以考虑动态图 fallback, raise NotImplementException 异常。

三、 Hint

  1. 查看所有的属性方法:
import paddle
a = paddle.to_tensor([1])
dir (a.__class__) # 查看所有的属性和方法
print (a.__class__.size) # 查看属性的类型,可以看到是 method / unbound fuction / property
  1. TensorVariable 中的代码
    def __getattr__(self, name: str):
        if name in paddle_tensor_methods:
            from .callable import TensorMethodVariable

            return TensorMethodVariable(
                self, name, self.graph, tracker=GetAttrTracker(self, name)
            )
        elif name in ["shape", "dtype", "stop_gradient"]:
            return VariableFactory.from_value(
                getattr(self.meta, name),
                self.graph,
                tracker=GetAttrTracker(self, name),
            )
        elif name in ["T", "ndim", "size"]:
            return getattr(self, name)
        else:
            raise InnerError(f"Unknown Tensor attribute: {name}")

这里的 'T', 'ndim', 'size' 就是对应的处理。然后需要在 TensorVariable下实现一个。比如 shape :

    @property
    def shape(self):
        return ConstantVariable.wrap_literal(len(self.meta.shape))

Metadata

Metadata

Assignees

Labels

🐾 meow快乐喵喵开源活动专属认证

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions