Skip to content

Commit 4ddd410

Browse files
authored
【Hackathon 7th No.20】为 Paddle 新增 Tensor.set_ -part (#68681)
1 parent a43a11f commit 4ddd410

File tree

8 files changed

+494
-0
lines changed

8 files changed

+494
-0
lines changed

paddle/phi/infermeta/unary.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3430,6 +3430,15 @@ void PSendArrayInferMeta(const MetaTensor& x, int peer) {
34303430
"The peer (%d) for p_send op must be non-negative.", peer));
34313431
}
34323432

3433+
void SetInferMeta(const MetaTensor& x,
3434+
const std::vector<int64_t>& shape,
3435+
const std::vector<int64_t>& stride,
3436+
MetaTensor* out) {
3437+
out->set_dtype(x.dtype());
3438+
out->set_dims(common::make_ddim(shape));
3439+
out->set_strides(common::make_ddim(stride));
3440+
}
3441+
34333442
void SendV2InferMeta(const int peer, const int ring_id) {
34343443
PADDLE_ENFORCE_GE(
34353444
peer,

paddle/phi/infermeta/unary.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,11 @@ void FillSplitOutDims(const MetaTensor& x,
719719
const std::vector<int64_t>& sections_vec,
720720
std::vector<MetaTensor*>* out);
721721

722+
void SetInferMeta(const MetaTensor& x,
723+
const std::vector<int64_t>& shape,
724+
const std::vector<int64_t>& stride,
725+
MetaTensor* out);
726+
722727
void SequenceSoftmaxInferMeta(const MetaTensor& x, MetaTensor* out);
723728

724729
void SplitInferMeta(const MetaTensor& x_meta,

paddle/phi/kernels/set_kernel.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "paddle/phi/kernels/set_kernel.h"
15+
#include "paddle/phi/core/kernel_registry.h"
16+
17+
namespace phi {
18+
19+
template <typename T, typename Context>
20+
void SetKernel(const Context& dev_ctx,
21+
const DenseTensor& x,
22+
const DenseTensor& source,
23+
const std::vector<int64_t>& dims,
24+
const std::vector<int64_t>& stride,
25+
int64_t offset,
26+
DenseTensor* out) {
27+
auto meta = out->meta();
28+
meta.dims = DDim(dims.data(), static_cast<int>(dims.size()));
29+
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
30+
meta.offset = offset;
31+
if (x.IsSharedWith(source)) {
32+
out->set_meta(meta);
33+
} else {
34+
// reset holder to nullptr
35+
out->clear();
36+
*out = DenseTensor{source.Holder(), meta};
37+
}
38+
out->ShareInplaceVersionCounterWith(x);
39+
}
40+
41+
} // namespace phi
42+
43+
PD_REGISTER_KERNEL(set,
44+
CPU,
45+
ALL_LAYOUT,
46+
phi::SetKernel,
47+
bool,
48+
uint8_t,
49+
int8_t,
50+
int16_t,
51+
int,
52+
int64_t,
53+
float,
54+
double,
55+
phi::dtype::float16,
56+
phi::dtype::bfloat16,
57+
phi::dtype::complex<float>,
58+
phi::dtype::complex<double>) {}
59+
60+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
61+
PD_REGISTER_KERNEL(set,
62+
GPU,
63+
ALL_LAYOUT,
64+
phi::SetKernel,
65+
bool,
66+
uint8_t,
67+
int8_t,
68+
int16_t,
69+
int,
70+
int64_t,
71+
float,
72+
double,
73+
phi::dtype::float16,
74+
phi::dtype::bfloat16,
75+
phi::dtype::complex<float>,
76+
phi::dtype::complex<double>) {}
77+
#endif

paddle/phi/kernels/set_kernel.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
19+
namespace phi {
20+
21+
template <typename T, typename Context>
22+
void SetKernel(const Context& dev_ctx,
23+
const DenseTensor& x,
24+
const DenseTensor& source,
25+
const std::vector<int64_t>& dims,
26+
const std::vector<int64_t>& stride,
27+
int64_t offset,
28+
DenseTensor* out);
29+
30+
} // namespace phi

paddle/phi/ops/yaml/ops.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4409,6 +4409,16 @@
44094409
backward: sequence_pool_grad
44104410
interfaces : paddle::dialect::InferSymbolicShapeInterface
44114411

4412+
- op : set
4413+
args : (Tensor x, Tensor source, int64_t[] dims = {}, int64_t[] stride = {}, int64_t offset = 0)
4414+
output : Tensor (out)
4415+
infer_meta :
4416+
func : SetInferMeta
4417+
param : [x, dims, stride]
4418+
kernel :
4419+
func : set
4420+
inplace : (x -> out)
4421+
44124422
- op : set_value_with_tensor
44134423
args : (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
44144424
output : Tensor(out)

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ones,
5353
ones_like,
5454
polar,
55+
set_,
5556
to_tensor,
5657
tril,
5758
tril_,
@@ -842,6 +843,7 @@
842843
"combinations",
843844
'signbit',
844845
'log_normal_',
846+
'set_',
845847
]
846848

847849
# this list used in math_op_patch.py for magic_method bind

python/paddle/tensor/creation.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,3 +3257,125 @@ def geometric_(
32573257
x.uniform_(min=float(tiny), max=float(1))
32583258
x.log_().divide_(paddle.log1p(-(probs)))
32593259
return x
3260+
3261+
3262+
@inplace_apis_in_dygraph_only
3263+
def set_(
3264+
x: paddle.Tensor,
3265+
source: paddle.Tensor | None = None,
3266+
shape: Sequence[int] | None = None,
3267+
stride: Sequence[int] | None = None,
3268+
offset: int = 0,
3269+
name: str | None = None,
3270+
) -> paddle.Tensor:
3271+
"""
3272+
set x with specified source Tensor's underlying storage, shape, stride and offset.
3273+
3274+
Note that the ``x`` will share the same data with ``source`` Tensor.
3275+
3276+
Args:
3277+
x (Tensor): An arbitrary Tensor. The data type supports ``bfloat16``, ``float16``, ``float32``, ``float64``,
3278+
``bool``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``complex64`` or ``complex128``.
3279+
source (Tensor|None, optional): Define the target Tensor to use. The data type supports `bfloat16`, ``float16``,
3280+
``float32``, ``float64``, ``bool``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``complex64`` or
3281+
``complex128``. Default: None, which means to set ``x`` with an empty source tensor.
3282+
shape (list|tuple|None, optional): Define the target shape. Each element of it should be integer. Default: None,
3283+
which means it will use the specified ``source``'s shape as default value.
3284+
stride (list|tuple|None, optional): Define the target stride. Each element of it should be integer. Default: None,
3285+
and when ``shape`` is also None, it will use the specified ``source``'s stride as default value; when ``shape``
3286+
is specified, it will use the default stride corresponding to the specified ``shape``.
3287+
offset (int, optional): Define the target offset from x's holder. Default: 0.
3288+
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
3289+
3290+
Returns:
3291+
Tensor, the Tensor with the same data type as ``x``.
3292+
3293+
Examples:
3294+
.. code-block:: python
3295+
3296+
>>> import paddle
3297+
3298+
>>> src = paddle.to_tensor([[11., 22., 33.]])
3299+
>>> src2 = paddle.to_tensor([11., 22., 33., 44., 55., 66.])
3300+
3301+
>>> x = paddle.to_tensor([1., 2., 3., 4., 5.])
3302+
>>> x.set_()
3303+
>>> print(x)
3304+
Tensor(shape=[0], dtype=float32, place=Place(cpu), stop_gradient=True,
3305+
[])
3306+
3307+
>>> x = paddle.to_tensor([1., 2., 3., 4., 5.])
3308+
>>> x.set_(src)
3309+
>>> print(x)
3310+
Tensor(shape=[1, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
3311+
[[11., 22., 33.]])
3312+
3313+
>>> print(x._is_shared_buffer_with(src))
3314+
True
3315+
3316+
>>> x = paddle.to_tensor([1., 2., 3., 4., 5.])
3317+
>>> x.set_(src, shape=[2, 1])
3318+
>>> print(x)
3319+
Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
3320+
[[11.],
3321+
[22.]])
3322+
3323+
>>> x = paddle.to_tensor([1., 2., 3., 4., 5.])
3324+
>>> x.set_(src2, shape=[3], stride=[2])
3325+
>>> print(x)
3326+
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
3327+
[11., 33., 55.])
3328+
3329+
>>> x = paddle.to_tensor([1., 2., 3., 4., 5.])
3330+
>>> x.set_(src2, shape=[5], offset=4)
3331+
>>> print(x)
3332+
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
3333+
[22., 33., 44., 55., 66.])
3334+
3335+
"""
3336+
if in_dynamic_mode():
3337+
# set_ doesn't have backward op so EagerUtils::CheckInplace will not be
3338+
# called in eager_generator.cc. Here to keep consistent with other inplace
3339+
# op, manually check whether x is leaf node and doesn't stop gradient.
3340+
if x.is_leaf and not x.stop_gradient:
3341+
raise ValueError(
3342+
f"(InvalidArgument) Leaf Tensor {x.name} that doesn't stop gradient can't use "
3343+
"inplace strategy."
3344+
)
3345+
if source is None:
3346+
source = paddle.empty([0], dtype=x.dtype)
3347+
shape = [0]
3348+
stride = [0]
3349+
else:
3350+
if not isinstance(source, (Variable, core.eager.Tensor)):
3351+
raise ValueError(
3352+
f"Input (source) should be paddle.Tensor but received {type(source)}"
3353+
)
3354+
check_dtype(
3355+
source.dtype,
3356+
'source',
3357+
[
3358+
'bool',
3359+
'float16',
3360+
'uint16',
3361+
'float32',
3362+
'float64',
3363+
'int8',
3364+
'int16',
3365+
'int32',
3366+
'int64',
3367+
'uint8',
3368+
'complex64',
3369+
'complex128',
3370+
],
3371+
'set',
3372+
)
3373+
if stride is None:
3374+
if shape is None:
3375+
stride = source.strides
3376+
else:
3377+
stride = paddle.empty(shape).strides
3378+
if shape is None:
3379+
shape = source.shape
3380+
3381+
return _C_ops.set_(x, source, shape, stride, offset)

0 commit comments

Comments
 (0)