Skip to content

Commit 69d6051

Browse files
authored
add F_pairwise_distance to pnnx and ncnn (#4942)
1 parent 1d7720e commit 69d6051

File tree

4 files changed

+104
-0
lines changed

4 files changed

+104
-0
lines changed

tools/pnnx/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ set(pnnx_pass_level2_SRCS
154154
pass_level2/F_mish.cpp
155155
pass_level2/F_normalize.cpp
156156
pass_level2/F_pad.cpp
157+
pass_level2/F_pairwise_distance.cpp
157158
pass_level2/F_pixel_shuffle.cpp
158159
pass_level2/F_pixel_unshuffle.cpp
159160
pass_level2/F_prelu.cpp
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Tencent is pleased to support the open source community by making ncnn available.
2+
//
3+
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
4+
//
5+
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6+
// in compliance with the License. You may obtain a copy of the License at
7+
//
8+
// https://opensource.org/licenses/BSD-3-Clause
9+
//
10+
// Unless required by applicable law or agreed to in writing, software distributed
11+
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12+
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13+
// specific language governing permissions and limitations under the License.
14+
15+
#include "pass_level2.h"
16+
17+
namespace pnnx {
18+
19+
class F_pairwise_distance : public GraphRewriterPass
20+
{
21+
public:
22+
const char* match_pattern_graph() const
23+
{
24+
return R"PNNXIR(7767517
25+
7 6
26+
pnnx.Input input_0 0 1 x1
27+
pnnx.Input input_1 0 1 x2
28+
prim::Constant op_0 0 1 p value=%p
29+
prim::Constant op_1 0 1 eps value=%eps
30+
prim::Constant op_2 0 1 keepdim value=%keepdim
31+
aten::pairwise_distance op_3 5 1 x1 x2 p eps keepdim out
32+
pnnx.Output output 1 0 out
33+
)PNNXIR";
34+
}
35+
36+
const char* type_str() const
37+
{
38+
return "F.pairwise_distance";
39+
}
40+
};
41+
42+
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pairwise_distance, 10)
43+
44+
} // namespace pnnx

tools/pnnx/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ pnnx_add_test(F_max_pool2d)
5454
pnnx_add_test(F_max_pool3d)
5555
pnnx_add_test(F_normalize)
5656
pnnx_add_test(F_pad)
57+
pnnx_add_test(F_pairwise_distance)
5758
pnnx_add_test(F_pixel_shuffle)
5859
pnnx_add_test(F_pixel_unshuffle)
5960
pnnx_add_test(F_prelu)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Tencent is pleased to support the open source community by making ncnn available.
2+
#
3+
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4+
#
5+
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6+
# in compliance with the License. You may obtain a copy of the License at
7+
#
8+
# https://opensource.org/licenses/BSD-3-Clause
9+
#
10+
# Unless required by applicable law or agreed to in writing, software distributed
11+
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12+
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
13+
# specific language governing permissions and limitations under the License.
14+
15+
import torch
16+
import torch.nn as nn
17+
import torch.nn.functional as F
18+
19+
class Model(nn.Module):
20+
def __init__(self):
21+
super(Model, self).__init__()
22+
23+
def forward(self, x, y):
24+
z1 = F.pairwise_distance(x,y,p=1,keepdim=False)
25+
z2 = F.pairwise_distance(x,y,p=2,keepdim=True)
26+
z3 = F.pairwise_distance(x,y)
27+
z4 = F.pairwise_distance(x,y,eps = 1e-3)
28+
return z1,z2,z3,z4
29+
30+
def test():
31+
net = Model()
32+
net.eval()
33+
34+
torch.manual_seed(0)
35+
x = torch.rand(12, 128, 128)
36+
y = torch.rand(12, 128, 128)
37+
38+
a0,a1,a2,a3 = net(x, y)
39+
40+
# export torchscript
41+
mod = torch.jit.trace(net, (x, y))
42+
mod.save("test_F_pairwise_distance.pt")
43+
44+
# torchscript to pnnx
45+
import os
46+
os.system("../src/pnnx test_F_pairwise_distance.pt inputshape=[12,128,128],[12,128,128]")
47+
48+
# pnnx inference
49+
import test_F_pairwise_distance_pnnx
50+
b0,b1,b2,b3 = test_F_pairwise_distance_pnnx.test_inference()
51+
52+
return torch.equal(a0,b0) and torch.equal(a1,b1) and torch.equal(a2,b2) and torch.equal(a3,b3)
53+
54+
if __name__ == "__main__":
55+
if test():
56+
exit(0)
57+
else:
58+
exit(1)

0 commit comments

Comments
 (0)