Skip to content

Conversation

MyPandaShaoxiang
Copy link
Collaborator

No description provided.

Copy link
Collaborator

@chenjiaoAngel chenjiaoAngel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenjiaoAngel chenjiaoAngel merged commit 6a0a1f0 into PaddlePaddle:develop Mar 25, 2020
@@ -65,6 +65,7 @@ class OpLite : public Registry {
virtual bool CheckShape() const { return true; }
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
virtual bool SmartInferShape() { return this->InferShape(); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不需要额外增加一个 SmartInferShape,按需修改单个 Op 特定的 InferShape 实现就可以了

@@ -150,6 +151,10 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_input_shapes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按需修改瓶颈Op 的 InferShape 即可,不要在积累加 member,这样所有创建的 Op 都带上这个 member 了。

在实际需要修改的 Op 上加这个 member,其他非瓶颈 op,就不需要改了。

如果这个升级对所有 op 都适用,再考虑加到基类上

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size_t io_shape_lod_hash_{};

存储上次执行前的 input shape, output shape, lod 相关的 hash combined.

for (int ele in elements) io_shape_lod_hash = hash_combine(io_shape_lod_hash , ele);

if (new_hash != io_shape_lod_hash_) {
InferShape();
}


this->InferShape();

if (!last_input_shapes.empty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里跟基类里面已经有的逻辑有啥区别吗? 为啥要再实现一次?

@@ -65,6 +65,7 @@ class OpLite : public Registry {
virtual bool CheckShape() const { return true; }
// Inference the outputs' shape.
virtual bool InferShape() const { return true; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续可以考虑 InferShape -> InferShapeImpl
SmartInferShape -> InferShape

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改名可以下个版本

@@ -150,6 +151,10 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
std::vector<DDimLite> last_input_shapes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size_t io_shape_lod_hash_{};

存储上次执行前的 input shape, output shape, lod 相关的 hash combined.

for (int ele in elements) io_shape_lod_hash = hash_combine(io_shape_lod_hash , ele);

if (new_hash != io_shape_lod_hash_) {
InferShape();
}

last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unique_ptr<>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants