Skip to content

Conversation

@megemini
Copy link
Contributor

PR types

New features

PR changes

APIs

Description

【Hackathon 5th No.31】为 Paddle 新增 column_stack / row_stack / dstack / hstack / vstack API

RFC:PaddlePaddle/community#684

涉及文件:

  • python/paddle/__init__.py 将 API 暴露出来
  • python/paddle/tensor/__init__.py 将 API 暴露出来
  • python/paddle/tensor/manipulation.py 实现 API
  • test/legacy_test/test_stack_extension_api.py 单元测试

请评审 ~

@paddle-bot
Copy link

paddle-bot bot commented Nov 19, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@megemini
Copy link
Contributor Author

@zhwesky2010

PR-CI-Coverage 中好像是有单测被删了?不清楚具体什么情况 ~

请评审!

else:
arrays.append(tensor)

return paddle.hstack(arrays, name=name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和torch的计算结果是一样的吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

借用一下 torch 官方的例子:

In [18]: import paddle

In [19]: import torch

In [20]: a = torch.tensor([1, 2, 3])
    ...: b = torch.tensor([4, 5, 6])
    ...: torch.column_stack((a, b))
Out[20]: 
tensor([[1, 4],
        [2, 5],
        [3, 6]])

In [21]: a = torch.arange(5)
    ...: b = torch.arange(10).reshape(5, 2)
    ...: torch.column_stack((a, b, b))
Out[21]: 
tensor([[0, 0, 1, 0, 1],
        [1, 2, 3, 2, 3],
        [2, 4, 5, 4, 5],
        [3, 6, 7, 6, 7],
        [4, 8, 9, 8, 9]])

In [22]: x = paddle.to_tensor([1, 2, 3])
    ...: y = paddle.to_tensor([4, 5, 6])
    ...: paddle.column_stack((x, y))
Out[22]: 
Tensor(shape=[3, 2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       [[1, 4],
        [2, 5],
        [3, 6]])

In [23]: x = paddle.arange(5)
    ...: y = paddle.arange(10).reshape((5, 2))
    ...: paddle.column_stack((x, y, y))
Out[23]: 
Tensor(shape=[5, 5], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       [[0, 0, 1, 0, 1],
        [1, 2, 3, 2, 3],
        [2, 4, 5, 4, 5],
        [3, 6, 7, 6, 7],
        [4, 8, 9, 8, 9]])

应该一样吧 ~

stack 没有 atleast_xd 的那种输入问题,stack 的输入只有一个 🤗 ~

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为我看你这个计算逻辑和torch有点区别:

def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    aligned_tensors = tuple(
        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
    )
    return cat(aligned_tensors, 1)

如果有不同的计算逻辑,需要说明更合理性。

Copy link
Contributor Author

@megemini megemini Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是指:

torch 的第二行 return cat(aligned_tensors, 1)

我这里用的 return paddle.hstack(arrays, name=name)

hstack 确实对于 ndim = 0 有特殊处理,但是实际上这里输入的 ndim 一定是大于 0 的,因此,与 return cat(aligned_tensors, 1) 是一样的啊 ~

那我还是改一下吧 ... ... 😅


p.s. 想起来了,当时用 hstack 而不是 concat 是因为: column_stack 和 row_stack 其实是 hstack 和 vstack 对等的实现,row_stack 用的 vstack,所以 column_stack 用的 hstack ~ 不过用 hstack 确实可能存在性能损失,已修改 ~ 👍

# the data feeded should NOT be a Tensor
feed[name] = input

out = func_paddle(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

静态图也测下反向吧, paddle.static.append_backward(out) 可以创建反向,然后fetch相应的x的梯度

@megemini
Copy link
Contributor Author

megemini commented Nov 23, 2023

@zhwesky2010

静态图使用 @test_with_pir_api 在 pir 下出错,在传统 static 下没有问题 ~

初步定位,好像是 pir 中 OpResult 的 hash 问题导致的,这个是已知问题吗?

> /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/autograd/ir_backward.py(690)calc_gradient_helper()
-> update_no_grad_set_by_stopgradient(block, no_grad_set)
(Pdb) l
685  
686         # check all inputs and outputs in the same block
687         check_all_puts(block, inputs, outputs)
688         # update no_grad_set if some value stop_gradient=True
689         pdb.set_trace()
690  ->     update_no_grad_set_by_stopgradient(block, no_grad_set)
691         complete_outputs, _, backward_ops = prepare_grad_outputs(
692             grad_outputs, outputs, state
693         )
694  
695         inputs_set = set(inputs)
(Pdb) n
> /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/autograd/ir_backward.py(691)calc_gradient_helper()
-> complete_outputs, _, backward_ops = prepare_grad_outputs(
(Pdb) p inputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c734449f0>]
(Pdb) p outputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73444af0>]
(Pdb) p inputs[0] in no_grad_set
True
(Pdb) p outputs[0] in no_grad_set
True
(Pdb) p inputs[0].__hash__()
70853872
(Pdb) p outputs[0].__hash__()
69594512
(Pdb) p [[a, a.__hash__()] for a in no_grad_set]
[[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446db0>, 60682464], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446e70>, 71025312], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446ef0>, 71724384], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446d30>, 71716080], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446d70>, 71720208], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446e30>, 71722000], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446df0>, 60682448], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446eb0>, 70853872], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446f30>, 69594512]]
(Pdb) 

可以看到,这里一个输入一个输出:

(Pdb) p inputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c734449f0>]
(Pdb) p outputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73444af0>]

其中输入的 hash 和 输出的 hash 如下:

(Pdb) p inputs[0].__hash__()
70853872
(Pdb) p outputs[0].__hash__()
69594512

再看 no_grad_set

(Pdb) p [[a, a.__hash__()] for a in no_grad_set]
[[<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446db0>, 60682464], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446e70>, 71025312], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446ef0>, 71724384], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446d30>, 71716080], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446d70>, 71720208], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446e30>, 71722000], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446df0>, 60682448], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446eb0>, 70853872], [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446f30>, 69594512]]

其中也有 hash 与输入输出相同的项目:

  • no_grad_set: [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446eb0>, 70853872]
  • 对应 input: [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c734449f0>], 70853872

以上两者 id 不同,应该是不一样的 ~

  • no_grad_set: [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73446f30>, 69594512]
  • 对应 output: [<paddle.base.libpaddle.pir.OpResult object at 0x7f1c73444af0>] 69594512

以上两者 id 不同,应该是不一样的 ~

在后面判断是否在集合中的时候,由于 hash 相同而判断通过,但实际两者应该是不同的?

这就导致梯度无法回传 (可以看测试代码,两者实际都已经 stop_gradient = False,理应可以梯度回传)~

如果手动改变后面对应的 input_grad_stopgradients 参数为 [[False]] ,则此 pir 下的静态图测试用例可以跑通 ~

还请帮忙确认一下 ~


补充一下,上面的调试是用:

  if paddle.framework.in_pir_mode():
      grads = paddle.autograd.ir_backward.grad(y, [out])
      out_grad = grads[0]
      fetch_list.append(out_grad)
  else:
      paddle.static.append_backward(y)
      out_grad = out.grad_name
      fetch_list.append(out_grad)

这样的回传方式测试的 ~ pir 下好像是要用 paddle.autograd.ir_backward.gradpaddle.static.append_backward 在 pir 下也会报错 ~

不过,这个 paddle.autograd.ir_backward.grad 的文档里面又写的 **This API is ONLY available in imperative mode.** ... ...

还请帮忙指导一下具体要怎么处理 🙏🙏🙏 ~

@megemini megemini requested a review from zhwesky2010 November 23, 2023 11:21
@megemini
Copy link
Contributor Author

block.refresh_stopgradient() 好像有点问题?refresh 之后输入输出的 stop_gradient 从 False 变成 True 了?

> /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/autograd/ir_backward.py(685)calc_gradient_helper()
-> block.refresh_stopgradient()
(Pdb) l
680  
681     def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
682         block = outputs[0].get_defining_op().get_parent_block()
683  
684         pdb.set_trace()
685  ->     block.refresh_stopgradient()
686  
687         #pdb.set_trace()
688         state = State(block.program)
689  
690         # check all inputs and outputs in the same block
(Pdb) p inputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7fc25fb0de70>]
(Pdb) p inputs[0].stop_gradient
False
(Pdb) p outputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7fc25fb0c2f0>]
(Pdb) p outputs[0].stop_gradient
False
(Pdb) s
> /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddle/autograd/ir_backward.py(688)calc_gradient_helper()
-> state = State(block.program)
(Pdb) p inputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7fc25fb0de70>]
(Pdb) p inputs[0].stop_gradient
True
(Pdb) p outputs
[<paddle.base.libpaddle.pir.OpResult object at 0x7fc25fb0c2f0>]
(Pdb) p outputs[0].stop_gradient
True

@megemini
Copy link
Contributor Author

@zhwesky2010 #59365 PR 里面我把问题重新整理了一下 ~

目前看依赖的底层算子好像有点问题?

@paddle-ci-bot
Copy link

paddle-ci-bot bot commented Dec 3, 2023

Sorry to inform you that b8f6da9's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@changeyoung98
Copy link
Contributor

关于stop gradient refresh关闭的pr已经合入:#59579

@megemini
Copy link
Contributor Author

megemini commented Dec 6, 2023

@zhwesky2010 @changeyoung98

CI 中的 PR-CI-Windows-Inference 出错了 ~

好像是 pir 中的 concat 的梯度回传有问题:

2023-12-06 18:20:33 ERROR: test_dtype (test_stack_extension_api.TestVStack)
2023-12-06 18:20:33 ----------------------------------------------------------------------
2023-12-06 18:20:33 Traceback (most recent call last):
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\test\legacy_test\test_stack_extension_api.py", line 304, in test_dtype
2023-12-06 18:20:33     self._test_all(
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\test\legacy_test\test_stack_extension_api.py", line 189, in _test_all
2023-12-06 18:20:33     self._test_static_api(self.func_paddle, self.func_numpy, *args)
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\pir_utils.py", line 115, in impl
2023-12-06 18:20:33     func(*args, **kwargs)
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\test\legacy_test\test_stack_extension_api.py", line 133, in _test_static_api
2023-12-06 18:20:33     res, res_grad = exe.run(
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 1720, in run
2023-12-06 18:20:33     res = self._run_pir_impl(
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 2027, in _run_pir_impl
2023-12-06 18:20:33     program, new_exe = self._executor_cache.get_pir_program_and_executor(
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 1085, in get_pir_program_and_executor
2023-12-06 18:20:33     return self._get_cached_program_and_executor_pir_mode(
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 1114, in _get_pir_program_and_executor
2023-12-06 18:20:33     new_exe = _StandaloneExecutor(place, plan, scope)
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 813, in __init__
2023-12-06 18:20:33     self._new_exe = self._create_new_executor()
2023-12-06 18:20:33   File "C:\home\workspace\Paddle\build\python\paddle\base\executor.py", line 849, in _create_new_executor
2023-12-06 18:20:33     new_exe = core.StandaloneExecutor(self._place, self._plan, self._scope)
2023-12-06 18:20:33 RuntimeError: 
2023-12-06 18:20:33 --------------------------------------
2023-12-06 18:20:33 C++ Traceback (most recent call last):
2023-12-06 18:20:33 --------------------------------------
2023-12-06 18:20:33 Not support stack backtrace yet.
2023-12-06 18:20:33 ----------------------
2023-12-06 18:20:33 Error Message Summary:
2023-12-06 18:20:33 ----------------------
2023-12-06 18:20:33 PreconditionNotMetError: op [pd_op.concat_grad] kernel output args defs should equal op outputs
2023-12-06 18:20:33   [Hint: Expected op_item->num_results() == output_defs.size(), but received op_item->num_results():1 != output_defs.size():0.] (at ..\paddle\fluid\pir\transforms\pd_op_to_kernel_pass.cc:1284)

出错的地方是在 test_dtype :

    def test_dtype(self):
        for dtype in DTYPE_ALL:
            if dtype == 'float16' and (
                not core.is_compiled_with_cuda()
                or not core.is_float16_supported(paddle.CUDAPlace(0))
            ):
                continue

            if dtype == 'bfloat16' and (
                not core.is_compiled_with_cuda()
                or not core.is_bfloat16_supported(paddle.CUDAPlace(0))
            ):
                continue

            # 这里进行具体的测试
            self._test_all(
                generate_data([], count=1, dtype=dtype),
                dtype,
            )

参考如下的测试程序:

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api

def vstack(x, name=None):
    arrays = paddle.atleast_2d(*x)
    if not isinstance(arrays, list):
        arrays = [arrays]

    return paddle.concat(arrays, axis=0, name=name)

@test_with_pir_api
def test_static_api():
    paddle.enable_static()

    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('x', (1, 2), 'float64')
        y = paddle.static.data('y', (1, 2), 'float64')

        x.stop_gradient = False
        y.stop_gradient = False

        feed = {'x': np.random.rand(1, 2), 'y': np.random.rand(1, 2)}

        out = vstack((x, y))
        out.stop_gradient = False

        z = out * 123

        fetch_list = [out]

        if paddle.framework.in_pir_mode():
            grads = paddle.autograd.ir_backward.grad(z, x)
            out_grad = grads[0]
            fetch_list.append(out_grad)

            exe = paddle.static.Executor()

            *res, res_grad = exe.run(feed=feed, fetch_list=fetch_list)

            print(res, x.shape, res_grad, res_grad.shape)


if __name__ == '__main__':
    test_static_api()

在 ubuntu 下可以正常运行,CI 中的 Windows-Inference 有问题?是因为 Windows-Inference 对 dtype 有什么特殊要求?

谢谢!

@zhwesky2010
Copy link
Contributor

@megemini 和Linux逻辑一般是一致的,API/OP开发没有特殊要求,可能是具体的C++逻辑 有误。最好使用windows开发机调下问题

@megemini
Copy link
Contributor Author

megemini commented Dec 9, 2023

@zhwesky2010

之前的问题应该是 windows-inference 的包对于 float16/bfloat16 支持有问题 ~

目前跳过 win32 gpu 的 float16/bfloat16 ,CI 已经通过,请帮忙看看可否 ~

请评审!谢谢!

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 changed the title 【Hackathon 5th No.31】为 Paddle 新增 column_stack / row_stack / dstack / hstack / vstack API 【Hackathon 5th No.31】为 Paddle 新增 column_stack / row_stack / dstack / hstack / vstack API -part Dec 12, 2023
@luotao1 luotao1 merged commit 0bdee96 into PaddlePaddle:develop Dec 12, 2023
@luotao1
Copy link
Contributor

luotao1 commented Dec 12, 2023

@megemini 顺师傅,可以提交中文文档了

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants