Skip to content

Commit 2e44eae

Browse files
authored
[Feature] Add NPU operator RotatedFeatureAlign (#2994)
1 parent 5040148 commit 2e44eae

File tree

4 files changed

+59
-3
lines changed

4 files changed

+59
-3
lines changed

docs/en/understand_mmcv/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc.
4141
| PointsInBoxes ||| | | |
4242
| PointsInPolygons | || | ||
4343
| PSAMask |||| ||
44-
| RotatedFeatureAlign |||| | |
44+
| RotatedFeatureAlign |||| | |
4545
| RoIPointPool3d | ||| | |
4646
| RoIPool | ||| ||
4747
| RoIAlignRotated |||| | |

docs/zh_cn/understand_mmcv/ops.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ MMCV 提供了检测、分割等任务中常用的算子
4141
| PointsInBoxes ||| | | |
4242
| PointsInPolygons | || | | |
4343
| PSAMask |||| ||
44-
| RotatedFeatureAlign |||| | |
44+
| RotatedFeatureAlign |||| | |
4545
| RoIPointPool3d | ||| | |
4646
| RoIPool | ||| ||
4747
| RoIAlignRotated |||| | |
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "pytorch_npu_helper.hpp"
2+
3+
using namespace NPU_NAME_SPACE;
4+
using namespace std;
5+
6+
void rotated_feature_align_forward_impl(const Tensor features,
7+
const Tensor best_bboxes,
8+
const float spatial_scale,
9+
const int points, Tensor output);
10+
11+
void rotated_feature_align_backward_impl(const Tensor top_grad,
12+
const Tensor best_bboxes,
13+
const float spatial_scale,
14+
const int points, Tensor bottom_grad);
15+
16+
void rotated_feature_align_forward_npu(const Tensor features,
17+
const Tensor best_bboxes,
18+
const float spatial_scale,
19+
const int points, Tensor output) {
20+
int64_t points_ = (int64_t)points;
21+
at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2);
22+
OpCommand cmd;
23+
cmd.Name("RotatedFeatureAlign")
24+
.Input(features)
25+
.Input(best_bboxes_)
26+
.Output(output)
27+
.Attr("spatial_scale", spatial_scale)
28+
.Attr("points", points_)
29+
.Run();
30+
}
31+
32+
void rotated_feature_align_backward_npu(const Tensor top_grad,
33+
const Tensor best_bboxes,
34+
const float spatial_scale,
35+
const int points, Tensor bottom_grad) {
36+
int64_t points_ = (int64_t)points;
37+
at::Tensor best_bboxes_ = best_bboxes.transpose(2, 3).transpose(1, 2);
38+
OpCommand cmd;
39+
cmd.Name("RotatedFeatureAlignGrad")
40+
.Input(top_grad)
41+
.Input(best_bboxes_)
42+
.Output(bottom_grad)
43+
.Attr("spatial_scale", spatial_scale)
44+
.Attr("points", points_)
45+
.Run();
46+
}
47+
48+
REGISTER_NPU_IMPL(rotated_feature_align_forward_impl,
49+
rotated_feature_align_forward_npu);
50+
51+
REGISTER_NPU_IMPL(rotated_feature_align_backward_impl,
52+
rotated_feature_align_backward_npu);

tests/test_ops/test_rotated_feature_align.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from mmcv.ops import rotated_feature_align
6-
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
6+
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
77

88

99
@pytest.mark.skipif(
@@ -17,6 +17,10 @@
1717
'mlu',
1818
marks=pytest.mark.skipif(
1919
not IS_MLU_AVAILABLE, reason='requires MLU support')),
20+
pytest.param(
21+
'npu',
22+
marks=pytest.mark.skipif(
23+
not IS_NPU_AVAILABLE, reason='requires NPU support')),
2024
pytest.param(
2125
'cpu',
2226
marks=pytest.mark.skipif(

0 commit comments

Comments
 (0)