Skip to content

Commit 19d92c5

Browse files
committed
Update cpp of pytorch
1 parent c2df691 commit 19d92c5

File tree

2 files changed

+50
-12
lines changed

2 files changed

+50
-12
lines changed

pytorch/cpp/octree_key.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ Tensor octree_xyz2key(Tensor xyz, int depth) {
5252

5353
Tensor key = torch::zeros_like(xyz);
5454
auto ptr_out = key.data_ptr<int64_t>();
55-
xyz2key_gpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
55+
if (key.is_cuda()) {
56+
xyz2key_gpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
57+
} else {
58+
xyz2key_cpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
59+
}
5660
return key;
5761
}
5862

@@ -64,7 +68,11 @@ Tensor octree_key2xyz(Tensor key, int depth) {
6468

6569
Tensor xyz = torch::zeros_like(key);
6670
auto ptr_out = xyz.data_ptr<int64_t>();
67-
key2xyz_gpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
71+
if (key.is_cuda()) {
72+
key2xyz_gpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
73+
} else {
74+
key2xyz_cpu((uintk*)ptr_out, (uintk*)ptr_in, num, depth);
75+
}
6876
return xyz;
6977
}
7078

@@ -91,12 +99,13 @@ Tensor octree_search_key(Tensor key, Tensor octree, int depth, bool key_is_xyz,
9199
const uintk* des_key = octree_.key_gpu(depth);
92100

93101
Tensor key_tmp;
94-
if (nempty) { // Search the non-empty octree nodes only
102+
if (nempty) { // Search the non-empty octree nodes only
95103
int top_h = des_h; // cache old des_h
96-
des_h = octree_.info().node_num_nempty(depth); // update des_h
104+
des_h = octree_.info().node_num_nempty(depth); // update des_h
97105
key_tmp = torch::zeros({des_h}, options);
98106
int64_t* tmp = key_tmp.data_ptr<int64_t>();
99-
pad_backward_gpu((uintk*)tmp, des_h, 1, des_key, top_h, octree_.children_gpu(depth));
107+
pad_backward_gpu((uintk*)tmp, des_h, 1, des_key, top_h,
108+
octree_.children_gpu(depth));
100109
des_key = (const uintk*)tmp;
101110
}
102111

pytorch/cpp/octree_property.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,20 @@ Tensor octree_property_gpu(Tensor octree_in, string property, int depth) {
2525
int channel = octree_.info().channel(OctreeInfo::kKey); // = 1
2626
int total_num = channel * nnum;
2727
data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
28-
memcpy_gpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
28+
uintk* des_ptr = (uintk*)data_out.data_ptr<int64_t>();
29+
if (octree_.info().is_key2xyz()) {
30+
if (depth > 0) {
31+
xyz2key_gpu(des_ptr, ptr, total_num, depth);
32+
} else {
33+
for (int d = 1; d < octree_depth + 1; d++) {
34+
int nnum_d = octree_.info().node_num(d);
35+
int ncum_d = octree_.info().node_num_cum(d);
36+
xyz2key_gpu(des_ptr + ncum_d, ptr + ncum_d, nnum_d, d);
37+
}
38+
}
39+
} else {
40+
memcpy_gpu(total_num, ptr, des_ptr);
41+
}
2942
}
3043

3144
else if (property == "xyz") {
@@ -78,7 +91,8 @@ Tensor octree_property_gpu(Tensor octree_in, string property, int depth) {
7891
CHECK(feature_ptr != nullptr) << "The features do not exist: d = " << depth;
7992
int channel = octree_.info().channel(OctreeInfo::kFeature);
8093
int total_num = channel * nnum;
81-
data_out = torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
94+
data_out =
95+
torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
8296
memcpy_gpu(total_num, feature_ptr, data_out.data_ptr<float>());
8397
}
8498

@@ -137,7 +151,7 @@ Tensor octree_property_gpu(Tensor octree_in, string property, int depth) {
137151
memcpy_gpu(1, &full_depth, data_out.data_ptr<int>());
138152
}
139153

140-
else{
154+
else {
141155
LOG(FATAL) << "Unsupport octree property: " << property;
142156
}
143157

@@ -161,7 +175,21 @@ Tensor octree_property_cpu(Tensor octree_in, string property, int depth) {
161175
int channel = octree_.info().channel(OctreeInfo::kKey); // = 1
162176
int total_num = channel * nnum;
163177
data_out = torch::zeros({total_num}, options.dtype(torch::kInt64));
164-
memcpy_cpu(total_num, ptr, (uintk*)data_out.data_ptr<int64_t>());
178+
uintk* des_ptr = (uintk*)data_out.data_ptr<int64_t>();
179+
if (octree_.info().is_key2xyz()) {
180+
if (depth > 0) {
181+
xyz2key_cpu(des_ptr, ptr, total_num, depth);
182+
} else {
183+
for (int d = 1; d < octree_depth + 1; d++) {
184+
int nnum_d = octree_.info().node_num(d);
185+
int ncum_d = octree_.info().node_num_cum(d);
186+
xyz2key_cpu(des_ptr + ncum_d, ptr + ncum_d, nnum_d, d);
187+
}
188+
}
189+
} else {
190+
memcpy_cpu(total_num, ptr, des_ptr);
191+
}
192+
165193
}
166194

167195
else if (property == "xyz") {
@@ -214,7 +242,8 @@ Tensor octree_property_cpu(Tensor octree_in, string property, int depth) {
214242
CHECK(feature_ptr != nullptr) << "The features do not exist: d = " << depth;
215243
int channel = octree_.info().channel(OctreeInfo::kFeature);
216244
int total_num = channel * nnum;
217-
data_out = torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
245+
data_out =
246+
torch::zeros({1, channel, nnum, 1}, options.dtype(torch::kFloat32));
218247
memcpy_cpu(total_num, feature_ptr, data_out.data_ptr<float>());
219248
}
220249

@@ -273,7 +302,7 @@ Tensor octree_property_cpu(Tensor octree_in, string property, int depth) {
273302
memcpy_cpu(1, &full_depth, data_out.data_ptr<int>());
274303
}
275304

276-
else{
305+
else {
277306
LOG(FATAL) << "Unsupport octree property: " << property;
278307
}
279308

@@ -293,7 +322,7 @@ Tensor octree_set_property_gpu(Tensor octree_in, Tensor data_in, int depth) {
293322
data_in = data_in.contiguous();
294323
CHECK_EQ(count, data_in.numel()) << "Wrong Property Size";
295324
memcpy_gpu(count, data_in.data_ptr<float>(), property_ptr);
296-
325+
297326
return octree_out;
298327
}
299328

0 commit comments

Comments
 (0)