# [Hackathon No.67] Add an `arange` operator to the neural network compiler CINN (#195)
Merged: zhhsplendid merged 10 commits into PaddlePaddle:master from MayYouBeProsperous:hackathon on Aug 10, 2022.
# CINN arange Design Document

| API name | arange |
|---|---|
| Author | MayYouBeProsperous |
| Submission date | 2022-08-02 |
| Version | V1.0 |
| Dependent CINN version | develop |
| File name | 20200802_cinn_api_design_arange.md |

# 1. Overview

## 1.1 Background
`arange` is a basic operator in neural network compilers. Given the two bounds of a numeric interval and a step size `step`, the operator outputs an evenly spaced sequence.

## 1.2 Terminology
tensor: a multidimensional array.
step: the difference between two adjacent elements of the sequence.

## 1.3 Goals
Implement the `arange` operator.
The operator takes a start value $start$, an end value $end$, and a step $step$, and outputs a sequence $(x_0,x_1,...,x_{n-1})$ of length $n=\left\lceil (end-start)/step \right\rceil$, satisfying $x_0=start$ and $x_{i+1} - x_i = step$ for $0 \leqslant i < n-1$.
The input parameters may be invalid, and some of them have default values, so both cases need to be handled.
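These semantics can be sketched in Python (a reference sketch only; `arange_ref` is a hypothetical name for illustration, not part of CINN):

```python
import math

def arange_ref(start, end, step):
    # Reference semantics of arange: an evenly spaced sequence with
    # x_0 = start and x_{i+1} - x_i = step, of length ceil((end - start) / step).
    if step == 0:
        raise ValueError("step must be non-zero")
    n = max(0, math.ceil((end - start) / step))
    return [start + i * step for i in range(n)]
```

For example, `arange_ref(1, 10, 2)` yields `[1, 3, 5, 7, 9]`, of length $\lceil 9/2 \rceil = 5$.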
## 1.4 Significance
Implementing the `arange` operator will further enrich CINN's library of basic operators.

# 2. Current State of CINN
The CINN framework does not yet support the `arange` operator; it needs to be implemented.

# 3. Survey of Existing Implementations
1. The core of TVM's `arange` implementation is shown below; it builds the output with a lambda expression:
```c++
inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step,
                     DataType dtype, std::string name = "T_arange", std::string tag = kInjective) {
  PrimExpr num_elem = tvm::cast(
      tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step));
  Array<PrimExpr> shape;
  return compute(
      {num_elem},
      [&](const Array<Var>& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name,
      tag);
}
```
2. The core of PyTorch/XLA's `arange` implementation is shown below:
```c++
torch::lazy::NodePtr ARange(const at::Scalar& start, const at::Scalar& end,
                            const at::Scalar& step,
                            at::ScalarType scalar_type) {
  xla::PrimitiveType type = MakeXlaPrimitiveType(scalar_type,
                                                 /*device=*/nullptr);
  XLA_CHECK_NE(step.toDouble(), 0.0);
  XLA_CHECK(!std::isnan(start.toDouble()) && !std::isnan(end.toDouble()))
      << "unsupported range: " << start.toDouble() << " -> " << end.toDouble();
  XLA_CHECK((start.toDouble() <= end.toDouble() && step.toDouble() > 0.0) ||
            (start.toDouble() >= end.toDouble() && step.toDouble() < 0.0));
  xla::Literal values;
  switch (type) {
    case xla::PrimitiveType::BF16:
      values = XlaHelpers::Range<tensorflow::bfloat16>(
          static_cast<tensorflow::bfloat16>(start.toFloat()),
          static_cast<tensorflow::bfloat16>(end.toFloat()),
          static_cast<tensorflow::bfloat16>(step.toFloat()));
      break;
    case xla::PrimitiveType::F16:
      values =
          XlaHelpers::Range<xla::half>(static_cast<xla::half>(start.toHalf()),
                                       static_cast<xla::half>(end.toHalf()),
                                       static_cast<xla::half>(step.toHalf()));
      break;
    case xla::PrimitiveType::F32:
      values = XlaHelpers::Range<float>(start.toFloat(), end.toFloat(),
                                        step.toFloat());
      break;
    case xla::PrimitiveType::F64:
      values = XlaHelpers::Range<double>(start.toDouble(), end.toDouble(),
                                         step.toDouble());
      break;
    case xla::PrimitiveType::U8:
      values = XlaHelpers::Range<uint8_t>(start.toByte(), end.toByte(),
                                          step.toByte());
      break;
    case xla::PrimitiveType::S8:
      values = XlaHelpers::Range<int8_t>(start.toChar(), end.toChar(),
                                         step.toChar());
      break;
    case xla::PrimitiveType::S16:
      values = XlaHelpers::Range<int16_t>(start.toShort(), end.toShort(),
                                          step.toShort());
      break;
    case xla::PrimitiveType::U16:
      values =
          XlaHelpers::Range<uint16_t>(start.toInt(), end.toInt(), step.toInt());
      break;
    case xla::PrimitiveType::S32:
      values =
          XlaHelpers::Range<int32_t>(start.toInt(), end.toInt(), step.toInt());
      break;
    case xla::PrimitiveType::U32:
      values = XlaHelpers::Range<uint32_t>(start.toLong(), end.toLong(),
                                           step.toLong());
      break;
    case xla::PrimitiveType::S64:
      values = XlaHelpers::Range<int64_t>(start.toLong(), end.toLong(),
                                          step.toLong());
      break;
    case xla::PrimitiveType::U64:
      values = XlaHelpers::Range<uint64_t>(start.toLong(), end.toLong(),
                                           step.toLong());
      break;
    default:
      XLA_ERROR() << "XLA type not supported: " << type;
  }
  return torch::lazy::MakeNode<Constant>(std::move(values));
}
```
The main helper, `XlaHelpers::Range`, is implemented as:
```c++
template <typename T>
static xla::Literal Range(T start, T end, T step) {
  return xla::LiteralUtil::CreateR1<T>(xla::util::Range<T>(start, end, step));
}

// xla::util::Range
template <typename T>
std::vector<T> Range(T start, T end, T step = 1) {
  std::vector<T> result;
  result.reserve(static_cast<size_t>((end - start) / step));
  if (start < end) {
    for (; start < end; start += step) {
      result.push_back(start);
    }
  } else {
    for (; start > end; start += step) {
      result.push_back(start);
    }
  }
  return result;
}
```
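To make the loop's behavior concrete, here is a direct Python transcription of `xla::util::Range` (a sketch for illustration; as in the C++ original, consistency between the direction and the sign of `step` is assumed to be checked by the caller, which is what the `XLA_CHECK` in `ARange` does):

```python
def xla_range(start, end, step=1):
    # Transcription of xla::util::Range: advance a cursor by `step`
    # and collect values until it crosses `end`, in either direction.
    result = []
    if start < end:
        while start < end:   # ascending branch: assumes step > 0
            result.append(start)
            start += step
    else:
        while start > end:   # descending branch: assumes step < 0
            result.append(start)
            start += step
    return result
```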

# 4. Comparative Analysis
The TVM and XLA implementations of `arange` are essentially similar.

# 5. Design and Implementation

## Naming and Parameter Design
start: start of the interval (inclusive); defaults to 0.
end: end of the interval (normally exclusive); defaults to None.
step: the step of the evenly spaced partition; defaults to 1.
dtype: the data type of the output `tensor`; supports int32, int64, float32, and float64. When this parameter is None, the output `tensor` has type int64. Defaults to None.
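The defaulting rules above can be sketched as follows (the helper name `resolve_arange_args` and the single-argument `arange(end)` form are illustrative assumptions modeled on NumPy/Paddle behavior, not CINN's confirmed API):

```python
def resolve_arange_args(start=0, end=None, step=1, dtype=None):
    # Apply the documented defaults: start -> 0, step -> 1,
    # dtype -> int64 when not given.
    if end is None:
        # Single-argument form (assumption): arange(end) == arange(0, end).
        start, end = 0, start
    if step == 0:
        raise ValueError("step must be non-zero")
    if dtype is None:
        dtype = "int64"
    if dtype not in ("int32", "int64", "float32", "float64"):
        raise ValueError("unsupported dtype: " + dtype)
    return start, end, step, dtype
```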
## Low-level OP Design
1. Declare the `arange` operator in `cinn/hlir/op/contrib/arange.h`.
2. Implement the `arange` operator and its `strategy` in `cinn/hlir/op/contrib/arange.cc`.

## API Implementation Plan
1. Declare `BaseBuilder::Arange` in `cinn/frontend/base_build.h`.
2. Implement `BaseBuilder::Arange` in `cinn/frontend/base_build.cc`.
3. In `cinn/pybind/frontend`, add an `arange` method to the Python `BaseBuilder` class and bind it to `BaseBuilder::Arange`.
4. Add the upper-level `load_paddle_model` call path in `cinn/frontend/paddle_model_to_program.h` and its `.cc` file.

From Python, `arange` is called through a method of the builder class:
```python
builder = CinnBuilder("test_basic")
b = builder.arange(1, 10, 1, "int32")
```
# 6. Testing and Acceptance Criteria
1. Provide a basic demo file.
2. Add tests for the low-level OP in `cinn/hlir/op/contrib/arange_test.cc`.
3. Add front-end tests in `cinn/frontend/net_builder_test.cc`.
4. Submit the API description to the corresponding documentation.

# 7. Feasibility Analysis and Schedule
- Feasibility: CINN already provides the Builder, Expr IR, and operator-registration modules, so the operator can be added cleanly on top of the existing framework.
- Schedule: the operator implementation, functional tests, and documentation are expected to be finished by September 1.

# 8. Impact
No impact on other modules.

# Appendix and References
[Deep Learning Framework Development Guide - Paddle Hackathon 3.0](https://aistudio.baidu.com/aistudio/course/introduce/26351?directly=1&shared=1)