-
Notifications
You must be signed in to change notification settings - Fork 5.9k
use operator context and infer context #3024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
299525d
11eabf8
4280a60
dda4881
fb1b3d1
081c7ca
5273c7e
0d693fe
bf3940b
362ba2f
a4bfb61
217186e
1c91df3
2460af4
fb1980e
9ff3595
9fafc46
e87d253
9a2640b
b6764c9
fab7737
e4445d6
eda5493
5f0ed40
bd8872c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,6 +79,10 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { | |
| outputs_.begin() + output_format.at(offset + 1)}; | ||
| } | ||
|
|
||
| void OperatorBase::InferShape(const std::shared_ptr<Scope>& scope) const { | ||
| InferShapeImpl(InferShapeContext(this, scope)); | ||
|
||
| } | ||
|
|
||
| std::string OperatorBase::DebugString() const { | ||
| std::stringstream ss; | ||
| ss << "Op(" << type_ << "), inputs:("; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,22 +31,9 @@ limitations under the License. */ | |
| namespace paddle { | ||
| namespace framework { | ||
|
|
||
| template <typename T> | ||
| struct EigenDeviceConverter; | ||
|
|
||
| template <> | ||
| struct EigenDeviceConverter<platform::CPUPlace> { | ||
| using EigenDeviceType = Eigen::DefaultDevice; | ||
| }; | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| template <> | ||
| struct EigenDeviceConverter<platform::GPUPlace> { | ||
| using EigenDeviceType = Eigen::GpuDevice; | ||
| }; | ||
| #endif | ||
|
|
||
| class OperatorBase; | ||
| class InferShapeContext; | ||
| class KernelContext; | ||
| /** | ||
| * OperatorBase has the basic element that Net will call to do computation. | ||
| * Only CreateOperator from OpRegistry will new Operator directly. User | ||
|
|
@@ -84,7 +71,8 @@ class OperatorBase { | |
|
|
||
| /// InferShape infer the size of Variables used by this Operator with | ||
| /// information inside scope | ||
| virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; | ||
| virtual void InferShape(const std::shared_ptr<Scope>& scope) const final; | ||
|
||
| virtual void InferShapeImpl(const InferShapeContext& ctx) const = 0; | ||
|
|
||
| /// Net will call this function to Run an op. | ||
| virtual void Run(const std::shared_ptr<Scope>& scope, | ||
|
|
@@ -110,29 +98,32 @@ class OperatorBase { | |
| std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; | ||
| }; | ||
|
|
||
| class KernelContext { | ||
| class OperatorContext { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OperatorContext => ExecutionContext? 我们会有两个概念分别叫做 OperatorContext 和 KernelContext 的吗?如果其实没有,那么就叫 Context 或者 ExecutionContext 是不是更清楚? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ExecutionContext is better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We actually have two contexts, one for InferShape, other for Run. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See line 35 and 36 |
||
| public: | ||
| KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : op_(*op), scope_(scope), device_context_(device_context) {} | ||
| OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
| : op_(*op), scope_(scope) {} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have to check OperatorBase* op not null first. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we donot need to check because context will only be construct inside a op, so op will nevel be null. And so be need not to use std::shared_ptr |
||
|
|
||
| int InputSize() const { return static_cast<int>(op_.inputs_.size()); } | ||
|
||
|
|
||
| const Variable* Input(int index) const { | ||
| int OutputSize() const { return static_cast<int>(op_.outputs_.size()); } | ||
|
|
||
| const Variable* InputVar(int index) const { | ||
|
||
| return scope_->GetVariable(op_.inputs_[index]); | ||
| } | ||
|
|
||
| Variable* Output(int index) const { | ||
| Variable* OutputVar(int index) const { | ||
| return scope_->GetVariable(op_.outputs_[index]); | ||
| } | ||
|
|
||
| const Variable* Input(const std::string& name) const { | ||
| const Variable* InputVar(const std::string& name) const { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可不可以 typedef std::vector<Tensor*> TensorArray;
TensorArray tensors = var.Get<TensorArray>();There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still need Isn't |
||
| return scope_->GetVariable(op_.Input(name)); | ||
| } | ||
|
|
||
| const Variable* Output(const std::string& name) const { | ||
| Variable* OutputVar(const std::string& name) const { | ||
| return scope_->GetVariable(op_.Output(name)); | ||
| } | ||
|
|
||
| const std::vector<const Variable*> Inputs(const std::string& name) const { | ||
| const std::vector<const Variable*> InputVars(const std::string& name) const { | ||
| auto names = op_.Inputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
|
|
@@ -141,7 +132,7 @@ class KernelContext { | |
| return res; | ||
| } | ||
|
|
||
| const std::vector<const Variable*> Outputs(const std::string& name) const { | ||
| const std::vector<const Variable*> OutputVars(const std::string& name) const { | ||
| auto names = op_.Outputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
|
|
@@ -150,15 +141,80 @@ class KernelContext { | |
| return res; | ||
| } | ||
|
|
||
| const Tensor& Input(int index) const { | ||
| return InputVar(index)->Get<Tensor>(); | ||
| } | ||
|
|
||
| Tensor* Output(int index) const { | ||
| return OutputVar(index)->GetMutable<Tensor>(); | ||
| } | ||
|
|
||
| const Tensor& Input(const std::string& name) const { | ||
| return InputVar(name)->Get<Tensor>(); | ||
| } | ||
|
|
||
| Tensor* Output(const std::string& name) const { | ||
| return OutputVar(name)->GetMutable<Tensor>(); | ||
| } | ||
|
|
||
| const std::vector<const Tensor*> Inputs(const std::string& name) const { | ||
| auto names = op_.Inputs(name); | ||
| std::vector<const Tensor*> res; | ||
| std::transform(names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { | ||
| return &scope_->GetVariable(name)->Get<Tensor>(); | ||
| }); | ||
| return res; | ||
| } | ||
|
|
||
| std::vector<const Tensor*> Outputs(const std::string& name) const { | ||
| auto names = op_.Outputs(name); | ||
| std::vector<const Tensor*> res; | ||
| std::transform(names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { | ||
| return scope_->GetVariable(name)->GetMutable<Tensor>(); | ||
| }); | ||
| return res; | ||
| } | ||
|
|
||
| const OperatorBase& op_; | ||
| const std::shared_ptr<Scope>& scope_; | ||
| }; | ||
|
|
||
| class InferShapeContext : public OperatorContext { | ||
| public: | ||
| InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
| : OperatorContext(op, scope) {} | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct EigenDeviceConverter; | ||
|
|
||
| template <> | ||
| struct EigenDeviceConverter<platform::CPUPlace> { | ||
| using EigenDeviceType = Eigen::DefaultDevice; | ||
| }; | ||
|
|
||
| #ifndef PADDLE_ONLY_CPU | ||
| template <> | ||
| struct EigenDeviceConverter<platform::GPUPlace> { | ||
| using EigenDeviceType = Eigen::GpuDevice; | ||
| }; | ||
| #endif | ||
|
|
||
| class KernelContext : public OperatorContext { | ||
| public: | ||
| KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : OperatorContext(op, scope), device_context_(device_context) {} | ||
|
|
||
| template <typename PlaceType, | ||
| typename DeviceType = | ||
| typename EigenDeviceConverter<PlaceType>::EigenDeviceType> | ||
| DeviceType* GetEigenDevice() const; | ||
|
|
||
| platform::Place GetPlace() const { return device_context_.GetPlace(); } | ||
|
|
||
| const OperatorBase& op_; | ||
| const std::shared_ptr<Scope>& scope_; | ||
| const platform::DeviceContext& device_context_; | ||
| }; | ||
|
|
||
|
|
@@ -176,19 +232,6 @@ class OpKernel { | |
| virtual ~OpKernel() {} | ||
| }; | ||
|
|
||
| template <typename T> | ||
| struct VarToTensor {}; | ||
|
|
||
| template <> | ||
| struct VarToTensor<Tensor*> { | ||
| Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); } | ||
| }; | ||
|
|
||
| template <> | ||
| struct VarToTensor<const Tensor*> { | ||
| const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); } | ||
| }; | ||
|
|
||
| class OperatorWithKernel : public OperatorBase { | ||
| public: | ||
| struct OpKernelKey { | ||
|
|
@@ -223,35 +266,6 @@ class OperatorWithKernel : public OperatorBase { | |
| static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels; | ||
| return g_all_op_kernels; | ||
| } | ||
|
|
||
| void InferShape(const std::shared_ptr<Scope>& scope) const final { | ||
| std::vector<const Tensor*> ins; | ||
| VarNamesToTensors(scope, inputs_, &ins); | ||
| std::vector<Tensor*> outs; | ||
| VarNamesToTensors(scope, outputs_, &outs); | ||
| InferShape(ins, outs); | ||
| }; | ||
|
|
||
| private: | ||
| template <typename T> | ||
| void VarNamesToTensors(const std::shared_ptr<Scope>& scope, | ||
| const std::vector<std::string>& var_names, | ||
| std::vector<T>* container) const { | ||
| container->reserve(var_names.size()); | ||
| VarToTensor<T> convert; | ||
| for (auto& name : var_names) { | ||
| auto var = scope->GetVariable(name); | ||
| if (var != nullptr) { | ||
| container->push_back(convert(var)); | ||
| } else { | ||
| container->push_back(nullptr); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| protected: | ||
| virtual void InferShape(const std::vector<const Tensor*>& inputs, | ||
| const std::vector<Tensor*>& outputs) const = 0; | ||
| }; | ||
|
|
||
| } // namespace framework | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my mind,
Implis usually a suffix for a class name, which implements an interface. Do we really need to name a functionImplhere?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right, maybe just called
InferShapeis cool.