Merged
Changes from 6 commits
Commits (25)
299525d
use operator context
jacquesqiao Jul 23, 2017
11eabf8
optimize code
jacquesqiao Jul 24, 2017
4280a60
update net infershape
jacquesqiao Jul 24, 2017
dda4881
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 24, 2017
fb1b3d1
update InferShape
jacquesqiao Jul 24, 2017
081c7ca
disable override InferShape(scope) in OperatorBase
jacquesqiao Jul 24, 2017
5273c7e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 25, 2017
0d693fe
change InferShapeImpl to InferShape
jacquesqiao Jul 25, 2017
bf3940b
add template to OperatorContext Input/Output
jacquesqiao Jul 26, 2017
362ba2f
merge Input InputVar, Output OutputVar
jacquesqiao Jul 26, 2017
a4bfb61
change Inputs to MultiInput
jacquesqiao Jul 26, 2017
217186e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 27, 2017
1c91df3
fix conflict
jacquesqiao Jul 27, 2017
2460af4
fix MultiInput bugs and add unit test
jacquesqiao Jul 27, 2017
fb1980e
rename KernelContext to ExecutionContext
jacquesqiao Jul 27, 2017
9ff3595
clean code
jacquesqiao Jul 27, 2017
9fafc46
change InferShape to protected
jacquesqiao Jul 28, 2017
e87d253
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 30, 2017
9a2640b
fix template bug
jacquesqiao Jul 30, 2017
b6764c9
refine code
jacquesqiao Jul 30, 2017
fab7737
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 30, 2017
e4445d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Jul 31, 2017
eda5493
use InputVar instead of Input<Variable>
jacquesqiao Jul 31, 2017
5f0ed40
typo
jacquesqiao Jul 31, 2017
bd8872c
optimize code
jacquesqiao Aug 1, 2017
4 changes: 2 additions & 2 deletions paddle/framework/net.h
@@ -57,9 +57,9 @@ class PlainNet : public Net {
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const std::shared_ptr<Scope>& scope) const override {
void InferShapeImpl(const InferShapeContext& ctx) const override {
Collaborator: In my mind, Impl is usually a suffix for a class name that implements an interface. Do we really need to name a function Impl here?

Collaborator: That's right; maybe just calling it InferShape is fine.

for (auto& op : ops_) {
op->InferShape(scope);
op->InferShape(ctx.scope_);
}
}

4 changes: 2 additions & 2 deletions paddle/framework/net_op_test.cc
@@ -16,8 +16,8 @@ static int run_cnt = 0;

class TestOp : public OperatorBase {
public:
void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {
void InferShapeImpl(
const paddle::framework::InferShapeContext& ctx) const override {
++infer_shape_cnt;
}
void Run(const std::shared_ptr<framework::Scope>& scope,
4 changes: 2 additions & 2 deletions paddle/framework/op_registry_test.cc
@@ -9,7 +9,7 @@ class CosineOp : public OperatorBase {
public:
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShapeImpl(const InferShapeContext& ctx) const override {}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -27,7 +27,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShapeImpl(const InferShapeContext& ctx) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
4 changes: 4 additions & 0 deletions paddle/framework/operator.cc
@@ -79,6 +79,10 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
outputs_.begin() + output_format.at(offset + 1)};
}

void OperatorBase::InferShape(const std::shared_ptr<Scope>& scope) const {
InferShapeImpl(InferShapeContext(this, scope));
Collaborator: Does scope here have to be of type shared_ptr? It seems simpler if we could use const Scope&.

Collaborator: It cannot be done, because inside InferShape/Run some operators, e.g. RNN, create a new local Scope, which takes a std::shared_ptr<Scope> as an argument.

Member: Passing a reference to the shared pointer here is indeed confusing; could you add a comment? A reference to the pointer implies the pointer itself may be modified. If it is, what happens to the object it previously pointed to? Could that leak memory?

}

std::string OperatorBase::DebugString() const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:(";
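The review thread on operator.cc above argues that the scope parameter cannot simply become const Scope&, because operators such as RNN build a child scope that shares ownership of its parent. The following is a minimal sketch of that pattern; Variable, Scope, and CreateVariable here are simplified stand-ins invented for illustration, not the framework's actual classes.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Toy stand-ins for the framework types (assumptions for illustration only).
struct Variable {};

class Scope {
 public:
  Scope() = default;
  // A child scope keeps its parent alive by sharing ownership, which is why
  // it wants a std::shared_ptr<Scope> rather than a const Scope&.
  explicit Scope(std::shared_ptr<Scope> parent) : parent_(std::move(parent)) {}

  Variable* CreateVariable(const std::string& name) { return &vars_[name]; }

 private:
  std::shared_ptr<Scope> parent_;
  std::unordered_map<std::string, Variable> vars_;
};

// An RNN-like operator that builds a per-step child scope inside Run/InferShape.
void RunRecurrentStep(const std::shared_ptr<Scope>& scope) {
  // The child scope co-owns its parent, so a bare const Scope& would not be
  // enough to construct it.
  auto step_scope = std::make_shared<Scope>(scope);
  step_scope->CreateVariable("step_state");
}

int main() {
  auto global_scope = std::make_shared<Scope>();
  RunRecurrentStep(global_scope);
}
```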
154 changes: 84 additions & 70 deletions paddle/framework/operator.h
@@ -31,22 +31,9 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif

class OperatorBase;
class InferShapeContext;
class KernelContext;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -84,7 +71,8 @@ class OperatorBase {

/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const final;
Collaborator: Here we use a reference to std::shared_ptr<Scope> as the parameter, and all InferShape calls share the same std::shared_ptr<Scope>. Does that mean there will be only one std::shared_ptr<Scope>? If so, why not use std::unique_ptr<Scope> instead?

virtual void InferShapeImpl(const InferShapeContext& ctx) const = 0;

/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
@@ -110,29 +98,32 @@ class OperatorBase {
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
};

class KernelContext {
class OperatorContext {
Collaborator: OperatorContext => ExecutionContext? Will we really have two concepts named OperatorContext and KernelContext? If not, wouldn't simply Context or ExecutionContext be clearer?

Member (author): ExecutionContext is better.

Collaborator: We actually have two contexts, one for InferShape and the other for Run.

Collaborator: See lines 35 and 36.

public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: op_(*op), scope_(scope) {}
Member: We have to check that OperatorBase* op is not null first. Also, we use const OperatorBase& op_ as a member; why not const std::shared_ptr<OperatorBase> op_?

Member (author): We do not need to check, because the context is only constructed inside an op, so op will never be null. For the same reason, there is no need for std::shared_ptr.


int InputSize() const { return static_cast<int>(op_.inputs_.size()); }
Collaborator: Why not size_t?


const Variable* Input(int index) const {
int OutputSize() const { return static_cast<int>(op_.outputs_.size()); }

const Variable* InputVar(int index) const {
Contributor: In OperatorBase, Input returns a string; in Context, Input returns a Tensor; and here is yet another InputVar.

Can we unify all the Input()s behind a single API like template <typename T> Input(std::string), implemented for three types: string, Tensor, and Variable?

That would be much simpler to understand. @jacquesqiao

Member (author): Great suggestion, thanks!

return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
Variable* OutputVar(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const Variable* Input(const std::string& name) const {
const Variable* InputVar(const std::string& name) const {
Collaborator: Could std::vector<const Variable*> as a whole be the return type of Variable::Get()? For example:

typedef std::vector<Tensor*> TensorArray;
TensorArray tensors = var.Get<TensorArray>();

Contributor: Do we still need InputVar when we have template <typename T> Input(name)? Isn't InputVar just Input<Variable>?

return scope_->GetVariable(op_.Input(name));
}

const Variable* Output(const std::string& name) const {
Variable* OutputVar(const std::string& name) const {
return scope_->GetVariable(op_.Output(name));
}

const std::vector<const Variable*> Inputs(const std::string& name) const {
const std::vector<const Variable*> InputVars(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Variable*> res;
std::transform(
@@ -141,7 +132,7 @@ class KernelContext {
return res;
}

const std::vector<const Variable*> Outputs(const std::string& name) const {
const std::vector<const Variable*> OutputVars(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Variable*> res;
std::transform(
@@ -150,15 +141,80 @@ class KernelContext {
return res;
}

const Tensor& Input(int index) const {
return InputVar(index)->Get<Tensor>();
}

Tensor* Output(int index) const {
return OutputVar(index)->GetMutable<Tensor>();
}

const Tensor& Input(const std::string& name) const {
return InputVar(name)->Get<Tensor>();
}

Tensor* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<Tensor>();
}

const std::vector<const Tensor*> Inputs(const std::string& name) const {
auto names = op_.Inputs(name);
std::vector<const Tensor*> res;
std::transform(names.begin(), names.end(), res.begin(),
[this](const std::string& name) {
return &scope_->GetVariable(name)->Get<Tensor>();
});
return res;
}

std::vector<const Tensor*> Outputs(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const Tensor*> res;
std::transform(names.begin(), names.end(), res.begin(),
[this](const std::string& name) {
return scope_->GetVariable(name)->GetMutable<Tensor>();
});
return res;
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
};

class InferShapeContext : public OperatorContext {
public:
InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope)
: OperatorContext(op, scope) {}
};

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif

class KernelContext : public OperatorContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: OperatorContext(op, scope), device_context_(device_context) {}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;

platform::Place GetPlace() const { return device_context_.GetPlace(); }

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};

@@ -176,19 +232,6 @@ class OpKernel {
virtual ~OpKernel() {}
};

template <typename T>
struct VarToTensor {};

template <>
struct VarToTensor<Tensor*> {
Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); }
};

template <>
struct VarToTensor<const Tensor*> {
const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); }
};

class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
@@ -223,35 +266,6 @@ class OperatorWithKernel : public OperatorBase {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
}

void InferShape(const std::shared_ptr<Scope>& scope) const final {
std::vector<const Tensor*> ins;
VarNamesToTensors(scope, inputs_, &ins);
std::vector<Tensor*> outs;
VarNamesToTensors(scope, outputs_, &outs);
InferShape(ins, outs);
};

private:
template <typename T>
void VarNamesToTensors(const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& var_names,
std::vector<T>* container) const {
container->reserve(var_names.size());
VarToTensor<T> convert;
for (auto& name : var_names) {
auto var = scope->GetVariable(name);
if (var != nullptr) {
container->push_back(convert(var));
} else {
container->push_back(nullptr);
}
}
}

protected:
virtual void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const = 0;
};

} // namespace framework
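The review thread on operator.h above suggests folding Input()/InputVar() into a single templated accessor, and the commit list includes "add template to OperatorContext Input/Output", which appears to adopt that idea. Below is a minimal sketch of how such an accessor could look; the Scope, Variable, Tensor, and Context classes are simplified stand-ins invented for illustration, not the framework's real types.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Toy stand-ins (assumptions for illustration only).
struct Tensor { int rows = 0; };

class Variable {
 public:
  template <typename T>
  const T& Get() const { return value_; }

 private:
  Tensor value_;
};

class Scope {
 public:
  // The toy scope creates a variable on first lookup.
  Variable* GetVariable(const std::string& name) { return &vars_[name]; }

 private:
  std::unordered_map<std::string, Variable> vars_;
};

// One templated accessor instead of separate Input()/InputVar() overloads:
// Input<Tensor>(name) unwraps the variable to its tensor, while
// Input<Variable>(name) returns the variable itself.
class Context {
 public:
  explicit Context(Scope* scope) : scope_(scope) {}

  template <typename T>
  const T* Input(const std::string& name) const {
    return &scope_->GetVariable(name)->Get<T>();
  }

 private:
  Scope* scope_;
};

// Specialization: asking for the Variable returns the variable, not a tensor.
template <>
const Variable* Context::Input<Variable>(const std::string& name) const {
  return scope_->GetVariable(name);
}

int main() {
  Scope scope;
  Context ctx(&scope);
  const Tensor* x = ctx.Input<Tensor>("X");       // unwrapped tensor
  const Variable* xv = ctx.Input<Variable>("X");  // raw variable
  assert(x != nullptr && xv != nullptr);
}
```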
8 changes: 4 additions & 4 deletions paddle/framework/operator_test.cc
@@ -24,7 +24,7 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShapeImpl(const framework::InferShapeContext& ctx) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
op_run_num++;
@@ -73,6 +73,7 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1);
}
@@ -97,8 +98,7 @@ static int cpu_kernel_run_num = 0;

class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {}
void InferShapeImpl(const framework::InferShapeContext& ctx) const override {}
};

template <typename T1, typename T2>
Expand All @@ -117,7 +117,7 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase {
public:
void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void InferShapeImpl(const framework::InferShapeContext& ctx) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
23 changes: 10 additions & 13 deletions paddle/operators/add_op.cc
@@ -21,17 +21,16 @@ namespace operators {

class AddOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr,
"Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
void InferShapeImpl(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input(0).dims() == ctx.Input(1).dims(),
"Two input of Add Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims());
ctx.Output(0)->Resize(ctx.Input(0).dims());
}
};

@@ -52,9 +51,7 @@ The equation is: Out = X + Y

class AddOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {}
void InferShapeImpl(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
LOG(INFO) << "AddOpGrad";
return "";
6 changes: 3 additions & 3 deletions paddle/operators/add_op.h
@@ -24,9 +24,9 @@ template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
auto input0 = context.Input(0);
auto input1 = context.Input(1);
auto* output = context.Output(0);

output->mutable_data<T>(context.GetPlace());

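The operator.h diff above has KernelContext resolve its Eigen device type at compile time through the EigenDeviceConverter trait, and kernels like the AddKernel just shown are templated on a Place. The following standalone sketch shows that trait pattern with toy place and device types standing in for platform::CPUPlace/GPUPlace and the Eigen devices; everything here is an illustrative assumption, not the framework's real code.

```cpp
#include <iostream>

// Toy place and device types (assumptions for illustration only).
struct CPUPlace {};
struct GPUPlace {};
struct DefaultDevice { const char* name() const { return "cpu"; } };
struct GpuDevice     { const char* name() const { return "gpu"; } };

// The trait: a compile-time mapping from a place type to its device type.
template <typename PlaceType>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<CPUPlace> {
  using EigenDeviceType = DefaultDevice;
};

template <>
struct EigenDeviceConverter<GPUPlace> {
  using EigenDeviceType = GpuDevice;
};

// A kernel templated on the place resolves its device type without any
// runtime branching, mirroring KernelContext::GetEigenDevice's default
// template argument.
template <typename PlaceType,
          typename DeviceType =
              typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
void Compute() {
  DeviceType device;
  std::cout << "running on " << device.name() << "\n";
}

int main() {
  Compute<CPUPlace>();  // prints "running on cpu"
  Compute<GPUPlace>();  // prints "running on gpu"
}
```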
18 changes: 8 additions & 10 deletions paddle/operators/cross_entropy_op.cc
@@ -21,21 +21,19 @@ namespace operators {

class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 2,
void InferShapeImpl(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2,
"Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1,
PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr,
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(outputs[0] != nullptr,
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
PADDLE_ENFORCE(ctx.Input(0).dims().size() == 2, "X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output(0)->dims().size() == 1,
"label's dimension must be 1.");
outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
ctx.Output(0)->Resize(framework::make_ddim({ctx.Input(0).dims()[0]}));
}
};
