Skip to content

Commit 4fc749a

Browse files
committed
Fix
1 parent 5b0034b commit 4fc749a

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/collective/c_concat_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class CConcatOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE_EQ(ctx->HasInput("X"),
26+
true,
27+
phi::errors::PreconditionNotMet(
28+
"Input 'X' of c_concat must be provided."));
29+
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"),
30+
true,
31+
phi::errors::PreconditionNotMet(
32+
"Output 'Out' of c_concat must be provided."));
33+
int nranks = ctx->Attrs().Get<int>("nranks");
34+
int rank = ctx->Attrs().Get<int>("rank");
35+
int ring_id = ctx->Attrs().Get<int>("ring_id");
36+
PADDLE_ENFORCE_GE(
37+
nranks,
38+
2,
39+
common::errors::InvalidArgument("The number of ranks (%d) for c_concat "
40+
"must be greater than 1.",
41+
nranks));
42+
PADDLE_ENFORCE_GE(
43+
ring_id,
44+
0,
45+
common::errors::InvalidArgument(
46+
"The ring_id (%d) for c_concat must be non-negative.", ring_id));
47+
PADDLE_ENFORCE_GE(
48+
rank,
49+
0,
50+
common::errors::InvalidArgument(
51+
"The rank (%d) for c_concat must be non-negative.", rank));
52+
PADDLE_ENFORCE_LT(rank,
53+
nranks,
54+
common::errors::InvalidArgument(
55+
"The value of rank (%d) for c_concat must "
56+
"be less than that of nranks.",
57+
rank,
58+
nranks));
59+
60+
phi::DDim dim = ctx->GetInputDim("X");
61+
dim[dim.size() - 1] = dim[dim.size() - 1] * nranks;
62+
if (dim[dim.size() - 1] < 0) dim[dim.size() - 1] = -1;
63+
ctx->SetOutputDim("Out", dim);
64+
}
65+
66+
protected:
67+
phi::KernelKey GetExpectedKernelType(
68+
const framework::ExecutionContext& ctx) const override {
69+
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
70+
ctx.GetPlace());
71+
}
72+
};
73+
74+
template <typename T>
75+
class CConcatOpGradMaker : public framework::SingleGradOpMaker<T> {
76+
public:
77+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
78+
79+
protected:
80+
void Apply(GradOpPtr<T> retv) const override {
81+
retv->SetType("c_split");
82+
retv->SetInput("X", this->OutputGrad("Out"));
83+
retv->SetOutput("Out", this->InputGrad("X"));
84+
retv->SetAttrMap(this->Attrs());
85+
}
86+
};
87+
88+
class CConcatOpMaker : public framework::OpProtoAndCheckerMaker {
89+
public:
90+
void Make() override {
91+
AddInput("X", "(Tensor) tensor to be concated.");
92+
AddOutput("Out", "(Tensor) the result of concat.");
93+
AddAttr<int>("rank", "(int default 0) rank id.").SetDefault(0);
94+
AddAttr<int>("nranks", "(int default 1) number of ranks.").SetDefault(1);
95+
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
96+
AddAttr<bool>(
97+
"use_calc_stream",
98+
"(bool default true) eject CUDA operations to calculation stream.")
99+
.SetDefault(true);
100+
AddAttr<bool>("use_model_parallel",
101+
"(bool default true) use this op with model parallel.")
102+
.SetDefault(true);
103+
AddComment(R"DOC(
104+
CConcat Operator
105+
AllGather the tensors on different trainers and concat them along the last dimension.
106+
)DOC");
107+
}
108+
};
109+
110+
} // namespace operators
111+
} // namespace paddle
112+
113+
namespace ops = paddle::operators;
114+
115+
REGISTER_OPERATOR(c_concat,
116+
ops::CConcatOp,
117+
ops::CConcatOpGradMaker<paddle::framework::OpDesc>,
118+
ops::CConcatOpGradMaker<paddle::imperative::OpBase>,
119+
ops::CConcatOpMaker);

0 commit comments

Comments
 (0)