-
Notifications
You must be signed in to change notification settings - Fork 1.6k
fix: fix infershape profile #3240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: fix infershape profile #3240
Conversation
test=develop
d9704c7
to
46b7f03
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -65,6 +65,7 @@ class OpLite : public Registry { | |||
virtual bool CheckShape() const { return true; } | |||
// Inference the outputs' shape. | |||
virtual bool InferShape() const { return true; } | |||
virtual bool SmartInferShape() { return this->InferShape(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不需要额外增加一个 SmartInferShape,按需修改单个 Op 特定的 InferShape 实现就可以了
@@ -150,6 +151,10 @@ class OpLite : public Registry { | |||
std::vector<Place> valid_places_; | |||
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; | |||
std::unique_ptr<OpInfo> op_info_; | |||
std::vector<DDimLite> last_input_shapes; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按需修改瓶颈Op 的 InferShape 即可,不要在积累加 member,这样所有创建的 Op 都带上这个 member 了。
在实际需要修改的 Op 上加这个 member,其他非瓶颈 op,就不需要改了。
如果这个升级对所有 op 都适用,再考虑加到基类上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t io_shape_lod_hash_{};
存储上次执行前的 input shape, output shape, lod 相关的 hash combined.
for (int ele in elements) io_shape_lod_hash = hash_combine(io_shape_lod_hash , ele);
if (new_hash != io_shape_lod_hash_) {
InferShape();
}
|
||
this->InferShape(); | ||
|
||
if (!last_input_shapes.empty()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里跟基类里面已经有的逻辑有啥区别吗? 为啥要再实现一次?
@@ -65,6 +65,7 @@ class OpLite : public Registry { | |||
virtual bool CheckShape() const { return true; } | |||
// Inference the outputs' shape. | |||
virtual bool InferShape() const { return true; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续可以考虑 InferShape -> InferShapeImpl
SmartInferShape -> InferShape
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改名可以下个版本
@@ -150,6 +151,10 @@ class OpLite : public Registry { | |||
std::vector<Place> valid_places_; | |||
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)}; | |||
std::unique_ptr<OpInfo> op_info_; | |||
std::vector<DDimLite> last_input_shapes; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
size_t io_shape_lod_hash_{};
存储上次执行前的 input shape, output shape, lod 相关的 hash combined.
for (int ele in elements) io_shape_lod_hash = hash_combine(io_shape_lod_hash , ele);
if (new_hash != io_shape_lod_hash_) {
InferShape();
}
last_output_shapes.clear(); | ||
last_output_lods.clear(); | ||
} | ||
last_output_shapes.push_back(param_.output->dims()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unique_ptr<>
No description provided.