Commit c62c79f

【Hackathon 7th Fundable Projects 2 No.89】 [fluid_ops] custom operator (PaddlePaddle#68540)
* Fix
* Fix
1 parent 5df37b6 commit c62c79f

4 files changed: 398 additions, 34 deletions

paddle/fluid/operators/custom_device_common_op_registry.cc

Lines changed: 0 additions & 34 deletions
@@ -1316,40 +1316,6 @@ void FeedDenseTensorKernel(const Context& dev_ctx,
 
 void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
-  auto device_type = dev_type.c_str();
-  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
-      c_concat,
-      device_type,
-      paddle::operators::CConcatOpCustomDeviceKernel<phi::CustomContext, float>,
-      paddle::operators::CConcatOpCustomDeviceKernel<phi::CustomContext,
-                                                     phi::dtype::float16>,
-      paddle::operators::CConcatOpCustomDeviceKernel<phi::CustomContext,
-                                                     phi::dtype::bfloat16>);
-  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
-      c_softmax_with_cross_entropy,
-      device_type,
-      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
-          phi::CustomContext,
-          float>,
-      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
-          phi::CustomContext,
-          double>,
-      paddle::operators::CSoftmaxWithCrossEntropyOpCustomDeviceKernel<
-          phi::CustomContext,
-          phi::dtype::float16>) {}
-
-  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
-      c_softmax_with_cross_entropy_grad,
-      device_type,
-      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
-          phi::CustomContext,
-          float>,
-      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
-          phi::CustomContext,
-          double>,
-      paddle::operators::CSoftmaxWithCrossEntropyGradCustomDeviceKernel<
-          phi::CustomContext,
-          phi::dtype::float16>) {}
 
 #endif
 }
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/distributed/collective/process_group.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
namespace phi {

template <typename T, typename Context>
void CConcatKernel(const Context& dev_ctx,
                   const DenseTensor& x_in,
                   int rank,
                   int nranks,
                   int ring_id UNUSED,
                   bool use_calc_stream UNUSED,
                   bool use_model_parallel UNUSED,
                   DenseTensor* out) {
  auto x = &x_in;
  int rid = ring_id;
  auto place = dev_ctx.GetPlace();

  PADDLE_ENFORCE_GE(rank,
                    0,
                    common::errors::PreconditionNotMet(
                        "The value of rank (%d) for c_concat must be "
                        "greater than or equal to 0.",
                        rank));
  PADDLE_ENFORCE_GE(nranks,
                    2,
                    common::errors::PreconditionNotMet(
                        "The value of nranks (%d) for c_concat must be "
                        "greater than or equal to 2.",
                        nranks));
  PADDLE_ENFORCE_LT(rank,
                    nranks,
                    common::errors::PreconditionNotMet(
                        "The value of rank (%d) for c_concat must be "
                        "less than that of nranks (%d).",
                        rank,
                        nranks));

  phi::DenseTensor temp_out;
  phi::DDim temp_out_dims = x->dims();
  temp_out_dims[0] *= nranks;
  temp_out.Resize(temp_out_dims);
  dev_ctx.template Alloc<T>(&temp_out);

  auto map = distributed::ProcessGroupMapFromGid::getInstance();
  if (map->has(rid)) {
    // Use ProcessGroup
    distributed::ProcessGroup* pg = map->get(rid);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(*x);
    out_tensor.push_back(temp_out);
    auto task = pg->AllGather(in_tensor, out_tensor);
    task->Wait();
  } else {
    auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
        phi::distributed::CommContextManager::GetInstance().Get(
            std::to_string(rid)));
    PADDLE_ENFORCE_EQ(
        nranks,
        comm->GetSize(),
        common::errors::InvalidArgument(
            "nranks: %s should equal to %s", nranks, comm->GetSize()));

    int64_t send_numel = x->numel();
    const T* send_buff = x->data<T>();
    T* recv_buff = temp_out.data<T>();
    // should ExecutionContext for calc stream.
    auto& stream = *dev_ctx.GetStream();
    phi::DeviceManager::CCLAllGather(
        place.GetDeviceType(),
        reinterpret_cast<void*>(const_cast<T*>(send_buff)),
        recv_buff,
        send_numel,
        x->dtype(),
        comm->GetXcclComm(),
        stream);
  }
  std::vector<phi::DenseTensor> inputs;
  int axis = x->dims().size() - 1;
  auto out_dims = x->dims();
  out_dims[out_dims.size() - 1] *= nranks;
  int rows_per_tensor = x->dims()[0];
  int offset = 0;
  for (int i = 0; i < nranks; i++) {
    phi::DenseTensor temp = temp_out.Slice(offset, offset + rows_per_tensor);
    inputs.emplace_back(temp);
    offset += rows_per_tensor;
  }

  out->Resize(out_dims);
  std::vector<paddle::Tensor> inputs_t(inputs.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    auto t = std::make_shared<phi::DenseTensor>();
    t->ShareDataWith(inputs[i]);
    inputs_t[i].set_impl(t);
  }
  auto output = paddle::experimental::concat(inputs_t, axis);
  out->ShareDataWith(*reinterpret_cast<phi::DenseTensor*>(output.impl().get()));
}
}  // namespace phi

PD_REGISTER_KERNEL(c_concat,
                   Custom,
                   ALL_LAYOUT,
                   phi::CConcatKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#endif
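
As a rough, Paddle-free illustration of the data movement this new c_concat kernel performs (an all-gather of each rank's shard followed by a concat along the last axis), the standalone C++ sketch below simulates the same steps on plain vectors. The shapes, values, and the two in-process "ranks" are invented for illustration; the real kernel uses ProcessGroup / XCCL all-gather and paddle::experimental::concat as shown above.

// Standalone sketch (no Paddle): each of `nranks` ranks holds a [rows x cols]
// shard; an all-gather stacks the shards along dim 0, and the stacked slices
// are then laid side by side along the last axis -> [rows x cols * nranks].
// All shapes and values here are invented for illustration.
#include <cstdio>
#include <vector>

int main() {
  const int nranks = 2, rows = 2, cols = 3;

  // Local shard of each simulated rank, row-major [rows x cols].
  const std::vector<std::vector<float>> shards = {
      {0, 1, 2, 3, 4, 5},        // rank 0
      {10, 11, 12, 13, 14, 15},  // rank 1
  };

  // "All-gather": stack shards along dim 0 -> temp_out is [nranks*rows x cols].
  std::vector<float> temp_out;
  for (const auto& shard : shards)
    temp_out.insert(temp_out.end(), shard.begin(), shard.end());

  // Concat the per-rank slices along the last axis -> out is [rows x cols*nranks].
  std::vector<float> out(rows * cols * nranks);
  for (int r = 0; r < nranks; ++r)
    for (int i = 0; i < rows; ++i)
      for (int j = 0; j < cols; ++j)
        out[i * cols * nranks + r * cols + j] =
            temp_out[(r * rows + i) * cols + j];

  // Each output row is rank 0's row followed by rank 1's row.
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols * nranks; ++j)
      std::printf("%5.1f ", out[i * cols * nranks + j]);
    std::printf("\n");
  }
  return 0;
}

With the made-up values above, the first output row is 0 1 2 10 11 12, i.e. the per-rank last-dimension shards placed side by side.
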
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/distributed/collective/process_group.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
namespace phi {

template <typename T, typename Context>
void CSoftmaxWithEntropyGradKernel(const Context& dev_ctx,
                                   const DenseTensor& softmax_in,
                                   const DenseTensor& label_in,
                                   const DenseTensor& loss_grad_in,
                                   int64_t ignore_index,
                                   int ring_id,
                                   int rank,
                                   int nranks,
                                   DenseTensor* logits_grad) {
  const phi::DenseTensor* labels = &label_in;
  const phi::DenseTensor* loss_grad = &loss_grad_in;
  const phi::DenseTensor* softmax = &softmax_in;
  phi::DenseTensor* logit_grad = logits_grad;

  if (logit_grad != softmax) {
    phi::Copy(dev_ctx, *softmax, dev_ctx.GetPlace(), false, logit_grad);
  }
  const auto softmax_dims = softmax->dims();
  const int axis = softmax_dims.size() - 1;
  const int N = phi::funcs::SizeToAxis(axis, softmax_dims);
  const int D = phi::funcs::SizeFromAxis(axis, softmax_dims);
  const auto& label_type = labels->dtype();

  if (label_type == phi::DataType::INT32 ||
      label_type == phi::DataType::INT64) {
    auto logit_grad_t = std::make_shared<phi::DenseTensor>();
    logit_grad_t->ShareDataWith(*logit_grad).Resize({N, D});
    auto loss_grad_t = std::make_shared<phi::DenseTensor>();
    loss_grad_t->ShareDataWith(*loss_grad).Resize({N});
    auto labels_1d = std::make_shared<phi::DenseTensor>();
    labels_1d->ShareDataWith(*labels).Resize({N});
    paddle::Tensor logits_grad_tensor(logit_grad_t),
        loss_grad_tensor(loss_grad_t), labels_1d_tensor(labels_1d);

    auto labels_1d_not_equal_ignore = paddle::experimental::reshape(
        paddle::experimental::not_equal(
            labels_1d_tensor,
            paddle::experimental::full_like(labels_1d_tensor,
                                            ignore_index,
                                            labels_1d_tensor.dtype(),
                                            labels_1d_tensor.place())),
        {N, 1});
    auto start_index_tensor =
        paddle::experimental::full_like(labels_1d_tensor,
                                        rank * D,
                                        labels_1d_tensor.dtype(),
                                        labels_1d_tensor.place());

    auto logits_grad_out_tensor1 = paddle::experimental::subtract(
        paddle::experimental::multiply(
            logits_grad_tensor,
            paddle::experimental::cast(labels_1d_not_equal_ignore,
                                       logits_grad_tensor.dtype())),
        paddle::experimental::cast(
            paddle::experimental::one_hot(
                paddle::experimental::subtract(labels_1d_tensor,
                                               start_index_tensor),
                D),
            logits_grad_tensor.dtype()));

    auto logits_grad_out_tensor2 = paddle::experimental::multiply(
        logits_grad_out_tensor1,
        paddle::experimental::reshape(loss_grad_tensor, {N, 1}));
    logit_grad
        ->ShareDataWith(*reinterpret_cast<phi::DenseTensor*>(
            logits_grad_out_tensor2.impl().get()))
        .Resize(softmax_dims);
  } else {
    PADDLE_THROW(common::errors::Unavailable(
        "CustomDevice c_softmax_with_cross_entropy_grad "
        "label_type only support int32/int64"));
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(c_softmax_with_cross_entropy_grad,
                   Custom,
                   ALL_LAYOUT,
                   phi::CSoftmaxWithEntropyGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif
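
The chain of paddle::experimental ops above boils down to a simple per-row formula for the sharded gradient: logits_grad[n][j] = (softmax[n][j] * (label[n] != ignore_index) - one_hot(label[n] - rank * D)[j]) * loss_grad[n]. The following standalone C++ sketch illustrates it with made-up shapes and values, assuming that a one-hot of an out-of-range index contributes zero (as when the label lives on another rank's shard or equals ignore_index); it is an illustration of the math, not of the Paddle APIs.

// Standalone sketch (no Paddle): per-row gradient of the sharded softmax
// cross-entropy. Rank `rank` owns the D consecutive classes starting at
// rank * D; rows whose label equals ignore_index get a zero gradient, and
// the one-hot term only fires when the label falls inside this rank's shard.
// All shapes and values here are invented for illustration.
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, D = 4;      // rows, classes owned by this rank
  const int rank = 1;          // this rank owns global classes [4, 8)
  const long ignore_index = -100;

  const std::vector<float> softmax = {0.10f, 0.20f, 0.30f, 0.40f,
                                      0.25f, 0.25f, 0.25f, 0.25f};
  const std::vector<long> label = {5, ignore_index};  // global class ids
  const std::vector<float> loss_grad = {1.0f, 1.0f};

  std::vector<float> logits_grad(N * D);
  for (int n = 0; n < N; ++n) {
    const float mask = (label[n] != ignore_index) ? 1.0f : 0.0f;
    const long local = label[n] - static_cast<long>(rank) * D;  // label - rank*D
    for (int j = 0; j < D; ++j) {
      const float one_hot = (local == j) ? 1.0f : 0.0f;  // 0 when out of range
      logits_grad[n * D + j] =
          (softmax[n * D + j] * mask - one_hot) * loss_grad[n];
    }
  }

  // Row 0 (label 5 -> local class 1): softmax with 1 subtracted at column 1.
  // Row 1 (ignore_index): all zeros.
  for (int n = 0; n < N; ++n) {
    for (int j = 0; j < D; ++j) std::printf("%6.2f ", logits_grad[n * D + j]);
    std::printf("\n");
  }
  return 0;
}
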
