4 changes: 1 addition & 3 deletions python/paddle/base/dygraph/math_op_patch.py
@@ -167,9 +167,7 @@ def _size_(var):
 def _T_(var):
     if len(var.shape) == 1:
         return var
-    perm = []
-    for i in range(len(var.shape)):
-        perm.insert(0, i)
+    perm = list(reversed(range(len(var.shape))))
     out = _C_ops.transpose(var, perm)
     return out

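A quick sanity check on the refactor above (a sketch, not part of the diff): the removed insert-at-front loop and the new one-liner build the same permutation. Assuming a rank-3 tensor:

    ndim = 3  # assumed rank; any rank behaves the same
    perm_old = []
    for i in range(ndim):
        perm_old.insert(0, i)  # prepending yields a reversed range
    perm_new = list(reversed(range(ndim)))
    assert perm_old == perm_new == [2, 1, 0]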
@@ -152,6 +152,8 @@ def visit_Call(self, node):
     Can't convert name of function call, because this will affect CallTransformer.
     """
     node.args = [self.visit(arg) for arg in node.args]
+    for keyword in node.keywords:
+        keyword.value = self.visit(keyword.value)
     node.func = self.visit(node.func)
     return node

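Context for the two added lines in visit_Call: in a Python ast.Call node, positional arguments live in node.args, while keyword arguments live in node.keywords as ast.keyword objects whose .value holds the actual expression, so a visitor that only walks node.args never rewrites expressions passed by keyword. A minimal standard-library sketch:

    import ast

    # For f(x, weight=self.w.T), `x` sits in call.args, but `self.w.T`
    # is only reachable through call.keywords[0].value.
    call = ast.parse("f(x, weight=self.w.T)", mode="eval").body
    print([ast.dump(kw.value) for kw in call.keywords])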
30 changes: 30 additions & 0 deletions python/paddle/pir/math_op_patch.py
@@ -405,6 +405,35 @@ def _size_(self):
     """
     return paddle.numel(self)

+@property
+def _T_(self):
+    """
+
+    Permute current Value with its dimensions reversed.
+
+    If `n` is the number of dimensions of `x`, `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> paddle.enable_static()
+
+            >>> x = paddle.ones(shape=[2, 3, 5])
+            >>> x_T = x.T
+
+            >>> exe = paddle.static.Executor()
+            >>> x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0]
+            >>> print(x_T_np.shape)
+            (5, 3, 2)
+
+    """
+    if len(self.shape) == 1:
+        return self
+    perm = list(reversed(range(len(self.shape))))
+
+    return _C_ops.transpose(self, perm)
+
 def clone(self):
     """
     Returns a new static Value, which is the clone of the original static
@@ -511,6 +540,7 @@ def value_hash(self):
     ('ndim', _ndim),
     ('astype', astype),
     ('size', _size_),
+    ('T', _T_),
     ('clone', clone),
     ('clear_gradient', clear_gradient),
     ('append', append),
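For reference, the new property mirrors an explicit transpose through the public API; a minimal sketch (assumes only documented paddle calls):

    import paddle

    paddle.enable_static()
    x = paddle.ones(shape=[2, 3, 5])
    perm = list(reversed(range(len(x.shape))))  # [2, 1, 0]
    y = paddle.transpose(x, perm)  # what x.T resolves to for rank >= 2
    print(y.shape)  # [5, 3, 2]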
22 changes: 22 additions & 0 deletions test/dygraph_to_static/test_load_transformer.py
@@ -71,5 +71,27 @@ def func(x):
     np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())


+class LoadInCallKwargsNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.extra_inputs = []
+
+    def forward(self, x):
+        for i in range(len(self.extra_inputs)):
+            x = paddle.nn.functional.linear(weight=self.extra_inputs[i].T, x=x)
+        return x
+
+
+class TestLoadInCallKwargs(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
+    def test_name_load_nograd(self):
+        net = LoadInCallKwargsNet()
+        x = paddle.rand([10, 10])
+        net.extra_inputs.append(paddle.rand([10, 10]))
+        output_st = paddle.jit.to_static(net)(x)
+        output_dy = net(x)
+        np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
+
+
 if __name__ == "__main__":
     unittest.main()
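The regression this test guards against: `.T` is reached only through a keyword argument, exactly the path visit_Call previously skipped. A standalone sketch (hypothetical function f, not from the test file):

    import paddle

    def f(x, w):
        # `.T` appears only inside a keyword argument
        return paddle.nn.functional.linear(weight=w.T, x=x)

    x = paddle.rand([4, 4])
    w = paddle.rand([4, 4])
    out = paddle.jit.to_static(f)(x, w)
    print(out.shape)  # [4, 4]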
10 changes: 10 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
@@ -450,6 +450,16 @@ def test_size(self):
         (output_x,) = exe.run(main_program, fetch_list=[x.size])
         self.assertEqual(output_x, 24)

+    def test_T(self):
+        with paddle.pir_utils.IrGuard():
+            main_program, exe, program_guard = new_program()
+            with program_guard:
+                x = paddle.assign(np.random.rand(2, 3, 4).astype("float32"))
+                x_T = x.T
+                self.assertEqual(x_T.shape, [4, 3, 2])
+                (output_x,) = exe.run(main_program, fetch_list=[x_T])
+                self.assertEqual(output_x.shape, (4, 3, 2))
+
     def test_hash_error(self):
         with paddle.pir_utils.IrGuard():
             _, _, program_guard = new_program()
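The expected shapes match NumPy's convention, where .T on an n-d array also reverses every axis; a NumPy-only sketch:

    import numpy as np

    x = np.random.rand(2, 3, 4).astype("float32")
    assert x.T.shape == (4, 3, 2)  # all axes reversed, like the new Value T property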