- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.9k
Refactoring InferShape #3946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring InferShape #3946
Changes from 31 commits
125a528
              8ab2d86
              1b7e6e3
              9e1ba61
              fcaea1f
              676c7fd
              703d6ce
              570ebb2
              80a1c63
              41996d2
              fc3b55c
              4e7058e
              267f0e3
              129599d
              92964d6
              53eb75a
              309765c
              93903fb
              872a570
              fc3c095
              8304c74
              59fa374
              afdfeb9
              228ddf8
              175abe6
              060677f
              91134b2
              827cec7
              28b5d0a
              feb9b1d
              df61245
              fb39fb3
              b421e91
              5fa1188
              509c40a
              4364314
              a3436ba
              35ea282
              7053581
              9acdb74
              80c785f
              d9c46c2
              c8e2aa8
              3fbfe5d
              e6ec26f
              c778cec
              5605a30
              40f2b53
              794d5df
              dc0d153
              4e55aae
              32c1d29
              a522653
              cdf9bfc
              e22d7c4
              d2d500e
              997cbc5
              50cec66
              01bfcdd
              017b27c
              f3bd1ad
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|  | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|  | @@ -15,6 +16,7 @@ limitations under the License. */ | |
|  | ||
| #include "paddle/framework/attribute.h" | ||
| #include "paddle/framework/framework.pb.h" | ||
| #include "paddle/framework/shape_inference.h" | ||
|  | ||
| namespace paddle { | ||
| namespace framework { | ||
|  | @@ -70,11 +72,26 @@ class OpProtoAndCheckerMaker { | |
|  | ||
| // Forwards `comment` to set_comment() on the OpProto held in proto_. | ||
| void AddComment(const std::string& comment) { | ||
| this->proto_->set_comment(comment); | ||
| } | ||
|  | ||
| // Stores `fn` as the shape inference function that GetShapeInferenceFn() | ||
| // returns. | ||
| void SetShapeInferenceFn(ShapeInferenceFn fn) { | ||
| this->shape_infer_fn_ = fn; | ||
| } | ||
|  | ||
| // Stores `fn` as the shape inference function that GetGradShapeInferenceFn() | ||
| // returns. | ||
| void SetGradShapeInferenceFn(ShapeInferenceFn fn) { | ||
| this->grad_shape_infer_fn_ = fn; | ||
| } | ||
|  | ||
| public: | ||
| // Returns a copy of the function stored by SetShapeInferenceFn(). The member | ||
| // is initialized with nullptr, so the result is empty when nothing was set. | ||
| // Returned as a plain (non-const) value so callers can move from it. | ||
| ShapeInferenceFn GetShapeInferenceFn() const { return shape_infer_fn_; } | ||
|          | ||
|  | ||
| // Returns a copy of the function stored by SetGradShapeInferenceFn(). The | ||
| // member is initialized with nullptr, so the result is empty when nothing was | ||
| // set. Returned as a plain (non-const) value so callers can move from it. | ||
| ShapeInferenceFn GetGradShapeInferenceFn() const { | ||
| return grad_shape_infer_fn_; | ||
| } | ||
|  | ||
| private: | ||
| void CheckNoDuplicatedInOutAttrs(); | ||
|  | ||
| OpProto* proto_; | ||
| OpAttrChecker* op_checker_; | ||
| ShapeInferenceFn shape_infer_fn_{nullptr}; | ||
| ShapeInferenceFn grad_shape_infer_fn_{nullptr}; | ||
| bool validated_{false}; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need the `validated_` flag? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This code has been removed. | ||
| }; | ||
|  | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -27,6 +27,7 @@ limitations under the License. */ | |
| #include "paddle/framework/op_proto_maker.h" | ||
| #include "paddle/framework/operator.h" | ||
| #include "paddle/framework/scope.h" | ||
| #include "paddle/framework/shape_inference_impl.h" | ||
|  | ||
| namespace paddle { | ||
| namespace framework { | ||
|  | @@ -35,10 +36,12 @@ class OpRegistry { | |
| public: | ||
| template <typename OpType, typename ProtoMakerType, typename GradOpType> | ||
| static void RegisterOp(const std::string& op_type, | ||
| const std::string& grad_op_type) { | ||
| const std::string& grad_op_type, | ||
| const ShapeInferenceFn fn) { | ||
|          | ||
| PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), | ||
| "'%s' is registered more than once.", op_type); | ||
| OpInfo op_info; | ||
| ShapeInferenceFn grad_op_inferer = nullptr; | ||
| op_info.creator_ = []( | ||
| const std::string& type, const VariableNameMap& inputs, | ||
| const VariableNameMap& outputs, const AttributeMap& attrs) { | ||
|  | @@ -52,18 +55,21 @@ class OpRegistry { | |
| auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); | ||
| maker.Validate(); | ||
| op_info.proto_->set_type(op_type); | ||
| op_info.shapeInferFn_ = maker.GetShapeInferenceFn(); | ||
| grad_op_inferer = maker.GetGradShapeInferenceFn(); | ||
| PADDLE_ENFORCE( | ||
| op_info.proto_->IsInitialized(), | ||
| "Fail to initialize %s's OpProto, because %s is not initialized", | ||
| op_type, op_info.proto_->InitializationErrorString()); | ||
| } else { | ||
| op_info.proto_ = nullptr; | ||
| op_info.checker_ = nullptr; | ||
| op_info.shapeInferFn_ = fn; | ||
| } | ||
| OpInfoMap::Instance().Insert(op_type, op_info); | ||
| // register gradient op | ||
| if (!grad_op_type.empty()) { | ||
| RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, ""); | ||
| RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "", grad_op_inferer); | ||
| } | ||
| } | ||
|  | ||
|  | @@ -75,6 +81,20 @@ class OpRegistry { | |
| static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); | ||
|  | ||
| static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op); | ||
|  | ||
| // Compile-time InferShape: creates the operator described by `op_desc` and | ||
| // calls the shape inference function registered for its type with a | ||
| // CompileTimeInferShapeContext built from that operator and `var_descs`. | ||
| static void InferShape(const OpDesc& op_desc, | ||
| std::map<std::string, VarDesc*>& var_descs) { | ||
| auto& info = OpInfoMap::Instance().Get(op_desc.type()); | ||
| // shapeInferFn_ may have been registered as nullptr; fail with a message | ||
| // naming the op rather than calling through an empty function. | ||
| PADDLE_ENFORCE(static_cast<bool>(info.shapeInferFn_), | ||
| "'%s' has no shape inference function registered.", | ||
| op_desc.type()); | ||
| auto op = OpRegistry::CreateOp(op_desc); | ||
| info.shapeInferFn_(CompileTimeInferShapeContext(op, var_descs)); | ||
| } | ||
|  | ||
| // Runtime InferShape: calls the shape inference function registered for | ||
| // op.Type() with a RunTimeInferShapeContext built from `op` and `scope`. | ||
| static void InferShape(const OperatorBase& op, const Scope& scope) { | ||
| auto& info = OpInfoMap::Instance().Get(op.Type()); | ||
| // shapeInferFn_ may have been registered as nullptr; fail with a message | ||
| // naming the op rather than calling through an empty function. | ||
| PADDLE_ENFORCE(static_cast<bool>(info.shapeInferFn_), | ||
| "'%s' has no shape inference function registered.", op.Type()); | ||
| info.shapeInferFn_(RunTimeInferShapeContext(op, scope)); | ||
| } | ||
| }; | ||
|  | ||
| class Registrar { | ||
|  | @@ -95,8 +115,8 @@ class OpRegistrar : public Registrar { | |
| public: | ||
| // Registers `op_type` with an empty gradient op type by delegating to the | ||
| // two-argument constructor (instead of constructing and discarding a | ||
| // temporary OpRegistrar inside the body). | ||
| explicit OpRegistrar(const char* op_type) : OpRegistrar(op_type, "") {} | ||
| OpRegistrar(const char* op_type, const char* grad_op_type) { | ||
| OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type, | ||
| grad_op_type); | ||
| OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>( | ||
| op_type, grad_op_type, nullptr); | ||
| } | ||
| }; | ||
|  | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and | |
| limitations under the License. */ | ||
|  | ||
| #include "paddle/framework/operator.h" | ||
| #include <algorithm> | ||
| #include "paddle/framework/op_registry.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why change these lines? If it is not necessary, please leave them unchanged. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 
 | ||
|  | ||
| namespace paddle { | ||
| namespace framework { | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|  | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|  | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|  | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|  | ||
| #pragma once | ||
|  | ||
| #include "paddle/framework/ddim.h" | ||
|  | ||
| namespace paddle { | ||
| namespace framework { | ||
|  | ||
| class InferShapeContextBase; | ||
|  | ||
| using ShapeInferenceFn = | ||
| std::function<void(const framework::InferShapeContextBase& ctx)>; | ||
|          | ||
|  | ||
| class InferShapeContextBase { | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么需要一个基类? 2、另外,在定义op的ShapeInference时候,只需要暴露CompileTime的InferShapeContext,不需要暴露基类。     SetShapeInferenceFn([](const framework::InferShapeContextBase &ctx) {
      auto dim0 = ctx.get_input_dim("X");
      auto dim1 = ctx.get_input_dim("Y");
      PADDLE_ENFORCE_EQ(dim0.size(), 2,
                        "input X should be a tensor with 2 dims, a matrix");
      PADDLE_ENFORCE_EQ(dim1.size(), 2,
                        "input Y should be a tensor with 2 dims, a matrix");
      PADDLE_ENFORCE_EQ(dim0[1], dim1[0],
                        "First matrix's width must be equal "
                        "with second matrix's height.");
      ctx.set_output_dim("Out", {dim0[0], dim1[1]});
    });
  }There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
 并不是这样的,compile time的infershape和runtime的infershape并没有太多本质区别,都要做: 
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 经过讨论,这个部分已经统一意见 | ||
| public: | ||
| virtual ~InferShapeContextBase() {} | ||
| virtual framework::DDim get_input_dim(const std::string& name) const = 0; | ||
| virtual void set_input_dim(const std::string& name, | ||
| const framework::DDim& dim) const = 0; | ||
| virtual framework::DDim get_output_dim(const std::string& name) const = 0; | ||
| virtual void set_output_dim(const std::string& name, | ||
| const DDim& dim) const = 0; | ||
| virtual AttrReader attrs() const = 0; | ||
|  | ||
| protected: | ||
| virtual framework::DDim get_dim(const std::string& name) const = 0; | ||
| virtual void set_dim(const std::string& name, | ||
| const framework::DDim& dim) const = 0; | ||
| }; | ||
|  | ||
| // No-op shape inference function: accepts a context and does nothing with it. | ||
| inline void NonFn(const framework::InferShapeContextBase& /*ctx*/) {} | ||
|  | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|  | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|  | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|  | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|  | ||
| #pragma once | ||
|  | ||
| #include "paddle/framework/ddim.h" | ||
| #include "paddle/framework/operator.h" | ||
| #include "paddle/framework/shape_inference.h" | ||
|  | ||
| namespace paddle { | ||
| namespace framework { | ||
|  | ||
| // InferShapeContextBase implementation used at compile time: dimensions are | ||
| // read from and written to the VarDesc entries of a name -> VarDesc* map, | ||
| // keyed by the variable names the owned operator reports for its | ||
| // inputs/outputs. | ||
| class CompileTimeInferShapeContext : public InferShapeContextBase { | ||
| public: | ||
| // Moves the operator out of `op` (the caller's unique_ptr is left empty) and | ||
| // keeps a reference to `var_descs`, which must outlive this context. | ||
| CompileTimeInferShapeContext(std::unique_ptr<OperatorBase>& op, | ||
| std::map<std::string, VarDesc*>& var_descs) | ||
| : op_(std::move(op)), var_descs_(var_descs) {} | ||
|  | ||
| // Dimensions of the variable bound to input slot `name`. | ||
| DDim get_input_dim(const std::string& name) const { | ||
| const auto& var_name = op_->Input(name); | ||
| return get_dim(var_name); | ||
| } | ||
|  | ||
| // Overwrites the dimensions of the variable bound to input slot `name`. | ||
| void set_input_dim(const std::string& name, const DDim& dim) const { | ||
| const auto& var_name = op_->Input(name); | ||
| set_dim(var_name, dim); | ||
| } | ||
|  | ||
| // Dimensions of the variable bound to output slot `name`. | ||
| DDim get_output_dim(const std::string& name) const { | ||
| const auto& var_name = op_->Output(name); | ||
| return get_dim(var_name); | ||
| } | ||
|  | ||
| // Overwrites the dimensions of the variable bound to output slot `name`. | ||
| void set_output_dim(const std::string& name, const DDim& dim) const { | ||
| const auto& var_name = op_->Output(name); | ||
| set_dim(var_name, dim); | ||
| } | ||
|  | ||
| // Reader over the attributes of the owned operator. | ||
| AttrReader attrs() const { return AttrReader(op_->Attrs()); } | ||
|  | ||
| private: | ||
| // Copies the dims stored in the VarDesc registered under variable `name` | ||
| // into a DDim. std::map::at throws std::out_of_range for an unknown name. | ||
| DDim get_dim(const std::string& name) const { | ||
| const VarDesc* var = var_descs_.at(name); | ||
| const auto& stored = var->lod_tensor().dims(); | ||
| std::vector<int64_t> shape(stored.begin(), stored.end()); | ||
| return make_ddim(shape); | ||
| } | ||
|  | ||
| // Replaces the dims stored in the VarDesc registered under variable `name` | ||
| // with the entries of `dim`, each narrowed to int. | ||
| void set_dim(const std::string& name, const DDim& dim) const { | ||
| auto* target = var_descs_.at(name)->mutable_lod_tensor(); | ||
| target->clear_dims(); | ||
| for (int idx = 0; idx < dim.size(); ++idx) { | ||
| target->add_dims(static_cast<int>(dim[idx])); | ||
| } | ||
| } | ||
|  | ||
| std::unique_ptr<OperatorBase> op_; | ||
| std::map<std::string, VarDesc*>& var_descs_; | ||
| }; | ||
|  | ||
| // InferShapeContextBase implementation used at run time: dimensions are read | ||
| // from and written to the Tensor variables found in a Scope, using the | ||
| // variable names the operator reports for its inputs/outputs. | ||
| class RunTimeInferShapeContext : public InferShapeContextBase { | ||
| public: | ||
| // Keeps references to `op` and `scope`; both must outlive this context. | ||
| RunTimeInferShapeContext(const OperatorBase& op, const Scope& scope) | ||
| : op_(op), scope_(scope) {} | ||
|  | ||
| // Dimensions of the variable bound to input slot `name`. | ||
| DDim get_input_dim(const std::string& name) const override { | ||
| return get_dim(op_.Input(name)); | ||
| } | ||
|  | ||
| // Resizes the variable bound to input slot `name` to `dim`. | ||
| void set_input_dim(const std::string& name, const DDim& dim) const override { | ||
| set_dim(op_.Input(name), dim); | ||
| } | ||
|  | ||
| // Dimensions of the variable bound to output slot `name`. | ||
| DDim get_output_dim(const std::string& name) const override { | ||
| return get_dim(op_.Output(name)); | ||
| } | ||
|  | ||
| // Resizes the variable bound to output slot `name` to `dim`. | ||
| void set_output_dim(const std::string& name, const DDim& dim) const override { | ||
| set_dim(op_.Output(name), dim); | ||
| } | ||
|  | ||
| // Reader over the attributes of the referenced operator. | ||
| AttrReader attrs() const override { return AttrReader(op_.Attrs()); } | ||
|  | ||
| private: | ||
| // `name` is already a variable name (the public getters resolve the slot | ||
| // before calling here), so it is looked up in the scope directly, matching | ||
| // set_dim, instead of being passed through op_.Input() a second time. | ||
| DDim get_dim(const std::string& name) const override { | ||
| Tensor* t = scope_.FindVar(name)->GetMutable<Tensor>(); | ||
| return t->dims(); | ||
| } | ||
|  | ||
| // Resizes the Tensor stored in scope variable `name` to `dim`. | ||
| void set_dim(const std::string& name, const DDim& dim) const override { | ||
| Tensor* t = scope_.FindVar(name)->GetMutable<Tensor>(); | ||
| t->Resize(dim); | ||
| } | ||
|  | ||
| const OperatorBase& op_; | ||
| const Scope& scope_; | ||
| }; | ||
|  | ||
| } // namespace framework | ||
| } // namespace paddle | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the purpose of this class is to add a method
`Get`, which can find and read an entry in `AttributeMap`. How about we change the definition of `AttributeMap` from the current one
`typedef std::unordered_map<std::string, Attribute> AttributeMap;` into
[the proposed replacement definition is missing from this extract]
so that we could define
`Get` as a method of type `AttributeMap`, instead of adding a new type `AttrReader`?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that if we use a class that inherits from the map, then we can't directly use list initialization to initialize this `AttributeMap`,
as in scale_op:
`AttributeMap attrs = {{"scale", Attr<AttrType>("scale")}};`
We will try to find a better way to do this in the next PR.