-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【Paddle TensorRT No.45,No.58,No.44】support embedding、unbind、logsigmoid pir-trt convert #70879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| axis = int(axis) | ||
|
|
||
| size_tensors = [] | ||
| newDims_tensors = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个变量的命名比较奇怪
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成了 new_shape_tenosr
| } | ||
| }; | ||
|
|
||
| class EmbeddingOpPattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lookup_table只支持动态shape,pir-trt默认支持动态shape,这个应该是无条件进入trt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| } | ||
| }; | ||
|
|
||
| class UnbindOpPattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| newDims_tensor = trt_concat(network, newDims_tensors) | ||
| stride = trt.Dims([1] * rank) | ||
| outputs = [] | ||
| output_size = len(paddle_op.results()[0].type().as_vec_type().as_list()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pir下get_output_names()应该能显示output_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个应该不行吧,这个op只有一个输出pir::VectorType, 要先dyn_cast成pir::VectorType,才能获取size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哦哦好的
PR Category
Inference
PR Types
Others
Description
pcard-71500
pir-convert:
1.pd.embedding
2.pd.unbind
3.pd.logsigmoid