- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.9k
use operator context and infer context #3024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
299525d
              11eabf8
              4280a60
              dda4881
              fb1b3d1
              081c7ca
              5273c7e
              0d693fe
              bf3940b
              362ba2f
              a4bfb61
              217186e
              1c91df3
              2460af4
              fb1980e
              9ff3595
              9fafc46
              e87d253
              9a2640b
              b6764c9
              fab7737
              e4445d6
              eda5493
              5f0ed40
              bd8872c
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -79,6 +79,10 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { | |
| outputs_.begin() + output_format.at(offset + 1)}; | ||
| } | ||
|  | ||
| void OperatorBase::InferShape(const std::shared_ptr<Scope>& scope) const { | ||
| InferShapeImpl(InferShapeContext(this, scope)); | ||
|          | ||
| } | ||
|  | ||
| std::string OperatorBase::DebugString() const { | ||
| std::stringstream ss; | ||
| ss << "Op(" << type_ << "), inputs:("; | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -31,22 +31,9 @@ limitations under the License. */ | |
| namespace paddle { | ||
| namespace framework { | ||
|  | ||
| template <typename T> | ||
| struct EigenDeviceConverter; | ||
|  | ||
| template <> | ||
| struct EigenDeviceConverter<platform::CPUPlace> { | ||
| using EigenDeviceType = Eigen::DefaultDevice; | ||
| }; | ||
|  | ||
| #ifndef PADDLE_ONLY_CPU | ||
| template <> | ||
| struct EigenDeviceConverter<platform::GPUPlace> { | ||
| using EigenDeviceType = Eigen::GpuDevice; | ||
| }; | ||
| #endif | ||
|  | ||
| class OperatorBase; | ||
| class InferShapeContext; | ||
| class KernelContext; | ||
| /** | ||
| * OperatorBase has the basic element that Net will call to do computation. | ||
| * Only CreateOperator from OpRegistry will new Operator directly. User | ||
|  | @@ -84,7 +71,8 @@ class OperatorBase { | |
|  | ||
| /// InferShape infer the size of Variables used by this Operator with | ||
| /// information inside scope | ||
| virtual void InferShape(const std::shared_ptr<Scope>& scope) const = 0; | ||
| virtual void InferShape(const std::shared_ptr<Scope>& scope) const final; | ||
|          | ||
| virtual void InferShapeImpl(const InferShapeContext& ctx) const = 0; | ||
|  | ||
| /// Net will call this function to Run an op. | ||
| virtual void Run(const std::shared_ptr<Scope>& scope, | ||
|  | @@ -110,29 +98,32 @@ class OperatorBase { | |
| std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_; | ||
| }; | ||
|  | ||
| class KernelContext { | ||
| class OperatorContext { | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OperatorContext => ExecutionContext? 我们会有两个概念分别叫做 OperatorContext 和 KernelContext 的吗?如果其实没有,那么就叫 Context 或者 ExecutionContext 是不是更清楚? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ExecutionContext is better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We actually have two contexts, one for InferShape, other for Run. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See line 35 and 36 | ||
| public: | ||
| KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : op_(*op), scope_(scope), device_context_(device_context) {} | ||
| OperatorContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
| : op_(*op), scope_(scope) {} | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have to check OperatorBase* op not null first. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we donot need to check because context will only be construct inside a op, so op will nevel be null. And so be need not to use std::shared_ptr | ||
|  | ||
| int InputSize() const { return static_cast<int>(op_.inputs_.size()); } | ||
|          | ||
|  | ||
| const Variable* Input(int index) const { | ||
| int OutputSize() const { return static_cast<int>(op_.outputs_.size()); } | ||
|  | ||
| const Variable* InputVar(int index) const { | ||
|          | ||
| return scope_->GetVariable(op_.inputs_[index]); | ||
| } | ||
|  | ||
| Variable* Output(int index) const { | ||
| Variable* OutputVar(int index) const { | ||
| return scope_->GetVariable(op_.outputs_[index]); | ||
| } | ||
|  | ||
| const Variable* Input(const std::string& name) const { | ||
| const Variable* InputVar(const std::string& name) const { | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可不可以  typedef std::vector<Tensor*> TensorArray;
TensorArray tensors = var.Get<TensorArray>();There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still need  Isn't  | ||
| return scope_->GetVariable(op_.Input(name)); | ||
| } | ||
|  | ||
| const Variable* Output(const std::string& name) const { | ||
| Variable* OutputVar(const std::string& name) const { | ||
| return scope_->GetVariable(op_.Output(name)); | ||
| } | ||
|  | ||
| const std::vector<const Variable*> Inputs(const std::string& name) const { | ||
| const std::vector<const Variable*> InputVars(const std::string& name) const { | ||
| auto names = op_.Inputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
|  | @@ -141,7 +132,7 @@ class KernelContext { | |
| return res; | ||
| } | ||
|  | ||
| const std::vector<const Variable*> Outputs(const std::string& name) const { | ||
| const std::vector<const Variable*> OutputVars(const std::string& name) const { | ||
| auto names = op_.Outputs(name); | ||
| std::vector<const Variable*> res; | ||
| std::transform( | ||
|  | @@ -150,15 +141,80 @@ class KernelContext { | |
| return res; | ||
| } | ||
|  | ||
| const Tensor& Input(int index) const { | ||
| return InputVar(index)->Get<Tensor>(); | ||
| } | ||
|  | ||
| Tensor* Output(int index) const { | ||
| return OutputVar(index)->GetMutable<Tensor>(); | ||
| } | ||
|  | ||
| const Tensor& Input(const std::string& name) const { | ||
| return InputVar(name)->Get<Tensor>(); | ||
| } | ||
|  | ||
| Tensor* Output(const std::string& name) const { | ||
| return OutputVar(name)->GetMutable<Tensor>(); | ||
| } | ||
|  | ||
| const std::vector<const Tensor*> Inputs(const std::string& name) const { | ||
| auto names = op_.Inputs(name); | ||
| std::vector<const Tensor*> res; | ||
| std::transform(names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { | ||
| return &scope_->GetVariable(name)->Get<Tensor>(); | ||
| }); | ||
| return res; | ||
| } | ||
|  | ||
| std::vector<const Tensor*> Outputs(const std::string& name) const { | ||
| auto names = op_.Outputs(name); | ||
| std::vector<const Tensor*> res; | ||
| std::transform(names.begin(), names.end(), res.begin(), | ||
| [this](const std::string& name) { | ||
| return scope_->GetVariable(name)->GetMutable<Tensor>(); | ||
| }); | ||
| return res; | ||
| } | ||
|  | ||
| const OperatorBase& op_; | ||
| const std::shared_ptr<Scope>& scope_; | ||
| }; | ||
|  | ||
| class InferShapeContext : public OperatorContext { | ||
| public: | ||
| InferShapeContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope) | ||
| : OperatorContext(op, scope) {} | ||
| }; | ||
|  | ||
| template <typename T> | ||
| struct EigenDeviceConverter; | ||
|  | ||
| template <> | ||
| struct EigenDeviceConverter<platform::CPUPlace> { | ||
| using EigenDeviceType = Eigen::DefaultDevice; | ||
| }; | ||
|  | ||
| #ifndef PADDLE_ONLY_CPU | ||
| template <> | ||
| struct EigenDeviceConverter<platform::GPUPlace> { | ||
| using EigenDeviceType = Eigen::GpuDevice; | ||
| }; | ||
| #endif | ||
|  | ||
| class KernelContext : public OperatorContext { | ||
| public: | ||
| KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope, | ||
| const platform::DeviceContext& device_context) | ||
| : OperatorContext(op, scope), device_context_(device_context) {} | ||
|  | ||
| template <typename PlaceType, | ||
| typename DeviceType = | ||
| typename EigenDeviceConverter<PlaceType>::EigenDeviceType> | ||
| DeviceType* GetEigenDevice() const; | ||
|  | ||
| platform::Place GetPlace() const { return device_context_.GetPlace(); } | ||
|  | ||
| const OperatorBase& op_; | ||
| const std::shared_ptr<Scope>& scope_; | ||
| const platform::DeviceContext& device_context_; | ||
| }; | ||
|  | ||
|  | @@ -176,19 +232,6 @@ class OpKernel { | |
| virtual ~OpKernel() {} | ||
| }; | ||
|  | ||
| template <typename T> | ||
| struct VarToTensor {}; | ||
|  | ||
| template <> | ||
| struct VarToTensor<Tensor*> { | ||
| Tensor* operator()(Variable* var) { return var->GetMutable<Tensor>(); } | ||
| }; | ||
|  | ||
| template <> | ||
| struct VarToTensor<const Tensor*> { | ||
| const Tensor* operator()(Variable* var) { return &var->Get<Tensor>(); } | ||
| }; | ||
|  | ||
| class OperatorWithKernel : public OperatorBase { | ||
| public: | ||
| struct OpKernelKey { | ||
|  | @@ -223,35 +266,6 @@ class OperatorWithKernel : public OperatorBase { | |
| static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels; | ||
| return g_all_op_kernels; | ||
| } | ||
|  | ||
| void InferShape(const std::shared_ptr<Scope>& scope) const final { | ||
| std::vector<const Tensor*> ins; | ||
| VarNamesToTensors(scope, inputs_, &ins); | ||
| std::vector<Tensor*> outs; | ||
| VarNamesToTensors(scope, outputs_, &outs); | ||
| InferShape(ins, outs); | ||
| }; | ||
|  | ||
| private: | ||
| template <typename T> | ||
| void VarNamesToTensors(const std::shared_ptr<Scope>& scope, | ||
| const std::vector<std::string>& var_names, | ||
| std::vector<T>* container) const { | ||
| container->reserve(var_names.size()); | ||
| VarToTensor<T> convert; | ||
| for (auto& name : var_names) { | ||
| auto var = scope->GetVariable(name); | ||
| if (var != nullptr) { | ||
| container->push_back(convert(var)); | ||
| } else { | ||
| container->push_back(nullptr); | ||
| } | ||
| } | ||
| } | ||
|  | ||
| protected: | ||
| virtual void InferShape(const std::vector<const Tensor*>& inputs, | ||
| const std::vector<Tensor*>& outputs) const = 0; | ||
| }; | ||
|  | ||
| } // namespace framework | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my mind,
Implis usually a suffix for a class name, which implements an interface. Do we really need to name a functionImplhere?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right, maybe just called
InferShapeis cool.