Skip to content

Commit bd450f6

Browse files
committed
Switches to pybind11 instead of nanobind
1 parent 59ca012 commit bd450f6

12 files changed

+113
-103
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ requires = [
33
"setuptools>=64",
44
"setuptools_scm>=8",
55
"scikit-build-core>=0.10",
6-
"nanobind",
6+
"pybind11",
77
"ipp-devel>=2022.1",
88
"ipp-static>=2022.1",
99
"intel-openmp",

src/cil/Binning.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
#include "ipp.h"
2525
#include <omp.h>
2626
#include "utilities.h"
27-
#include <nanobind/nanobind.h>
28-
#include <nanobind/ndarray.h>
2927

3028
class Binner {
3129

@@ -201,6 +199,6 @@ void* Binner_new(
201199
return new Binner(shape_in.data(), shape_out.data(), pixel_index_start.data(), binning_list.data());
202200
}
203201
int Binner_bin(void* binner, DataInput data_in, DataBinned data_binned) {
204-
return ((Binner*)binner)->bin(data_in.data(), data_binned.data());
202+
return ((Binner*)binner)->bin(data_in.data(), data_binned.mutable_data());
205203
}
206204

src/cil/CMakeLists.txt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@ set (CMAKE_CXX_STANDARD_REQUIRED ON)
2020
if(NOT DEFINED CMAKE_BUILD_TYPE)
2121
set(CMAKE_BUILD_TYPE RelWithDebInfo)
2222
endif()
23-
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
2423

25-
find_package(nanobind CONFIG REQUIRED)
24+
set(PYBIND11_FINDPYTHON ON)
25+
find_package(pybind11 CONFIG REQUIRED)
2626

2727
find_package(IPP REQUIRED CONFIG)
2828
find_package(OpenMP REQUIRED)
2929

30-
31-
3230
if (${CMAKE_CXX_COMPILER_ID} STREQUAL "GNUCC")
3331
# appends some flags
3432
add_compile_options(-ftree-vectorize -fopt-info-vec-optimized -fopt-info-vec)
@@ -52,7 +50,7 @@ else()
5250
set (OpenMP_EXE_LINKER_FLAGS ${OpenMP_C_FLAGS})
5351
endif()
5452

55-
list(APPEND cilacc_SOURCES utilities.cpp axpby.cpp FiniteDifferenceLibrary.cpp nanobind.cpp)
53+
list(APPEND cilacc_SOURCES utilities.cpp axpby.cpp FiniteDifferenceLibrary.cpp pybind11.cpp)
5654
list(APPEND cilacc_INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}/include)
5755
list(APPEND cilacc_LIBRARIES ${OpenMP_EXE_LINKER_FLAGS})
5856

@@ -67,9 +65,8 @@ endif()
6765
message(STATUS ${cilacc_INCLUDES})
6866
message(STATUS ${IPP_LIBRARIES})
6967

70-
nanobind_add_module(
68+
pybind11_add_module(
7169
cilacc
72-
NB_STATIC
7370
${cilacc_SOURCES}
7471
)
7572
include_directories(cilacc PUBLIC ${cilacc_INCLUDES})

src/cil/FBP_filtering.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ int filter_projections_avh(DataFloat data, DataFloatConst filter, DataFloatConst
6161
for (j = 0; j < half_pixy; j++)
6262
{
6363
row_start = (size_t)2 * j * pix_x;
64-
out_ptr = &data.data()[proj_start + row_start];
64+
out_ptr = &data.mutable_data()[proj_start + row_start];
6565
ippsMul_32f_I(weights.data()+row_start, out_ptr, 2* pix_x);
6666
ippsSet_32fc({ 0.f,0.f }, src, width);
6767
ippsRealToCplx_32f(out_ptr, out_ptr + pix_x, src + offset, pix_x);
@@ -76,7 +76,7 @@ int filter_projections_avh(DataFloat data, DataFloatConst filter, DataFloatConst
7676
if (pix_y % 2)
7777
{
7878
row_start = (size_t)pix_y * pix_x - pix_x;
79-
out_ptr = &data.data()[proj_start + row_start];
79+
out_ptr = &data.mutable_data()[proj_start + row_start];
8080

8181
ippsMul_32f_I(weights.data() + row_start, out_ptr, pix_x);
8282
ippsSet_32fc({ 0.f,0.f }, src, width);
@@ -140,7 +140,7 @@ int filter_projections_vah(DataFloat data, DataFloatConst filter, DataFloatConst
140140
for (j = 0; j < half_proj; j++)
141141
{
142142
row_start = (size_t)2 * j * pix_x;
143-
out_ptr = &data.data()[col_start + row_start];
143+
out_ptr = &data.mutable_data()[col_start + row_start];
144144
ippsMul_32f_I(weights_ptr, out_ptr, pix_x);
145145
ippsMul_32f_I(weights_ptr, out_ptr + pix_x, pix_x);
146146

@@ -157,7 +157,7 @@ int filter_projections_vah(DataFloat data, DataFloatConst filter, DataFloatConst
157157
if (num_proj % 2)
158158
{
159159
row_start = (size_t)num_proj * pix_x - pix_x;
160-
out_ptr = &data.data()[col_start + row_start];
160+
out_ptr = &data.mutable_data()[col_start + row_start];
161161
ippsMul_32f_I(weights_ptr, out_ptr, pix_x);
162162
ippsSet_32fc({ 0.f,0.f }, src, width);
163163
ippsRealToCplx_32f(out_ptr, NULL, src + offset, pix_x);

src/cil/FiniteDifferenceLibrary.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,11 @@ int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const
537537
return 0;
538538
}
539539

540-
int fdiff4D(nb::ndarray<float> imagefull,
541-
nb::ndarray<float> gradCfull,
542-
nb::ndarray<float> gradZfull,
543-
nb::ndarray<float> gradYfull,
544-
nb::ndarray<float> gradXfull,
540+
int fdiff4D(DataFloat imagefull,
541+
DataFloat gradCfull,
542+
DataFloat gradZfull,
543+
DataFloat gradYfull,
544+
DataFloat gradXfull,
545545
size_t nc, size_t nz, size_t ny, size_t nx,
546546
int boundary, int direction,
547547
int nThreads)
@@ -553,25 +553,25 @@ int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const
553553
if (boundary)
554554
{
555555
if (direction)
556-
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
556+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), gradZfull.mutable_data(), gradCfull.mutable_data(), nx, ny, nz, nc);
557557
else
558-
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
558+
status = fdiff_adjoint_periodic(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
559559
}
560560
else
561561
{
562562
if (direction)
563-
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
563+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), gradZfull.mutable_data(), gradCfull.mutable_data(), nx, ny, nz, nc);
564564
else
565-
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
565+
status = fdiff_adjoint_neumann(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), gradCfull.data(), nx, ny, nz, nc);
566566
}
567567

568568
omp_set_num_threads(nThreads_initial);
569569
return status;
570570
}
571-
int fdiff3D(nb::ndarray<float> imagefull,
572-
nb::ndarray<float> gradZfull,
573-
nb::ndarray<float> gradYfull,
574-
nb::ndarray<float> gradXfull,
571+
int fdiff3D(DataFloat imagefull,
572+
DataFloat gradZfull,
573+
DataFloat gradYfull,
574+
DataFloat gradXfull,
575575
size_t nz, size_t ny, size_t nx,
576576
int boundary, int direction,
577577
int nThreads)
@@ -583,24 +583,24 @@ int fdiff3D(nb::ndarray<float> imagefull,
583583
if (boundary)
584584
{
585585
if (direction)
586-
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
586+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), gradZfull.mutable_data(), NULL, nx, ny, nz, 1);
587587
else
588-
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
588+
status = fdiff_adjoint_periodic(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
589589
}
590590
else
591591
{
592592
if (direction)
593-
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
593+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), gradZfull.mutable_data(), NULL, nx, ny, nz, 1);
594594
else
595-
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
595+
status = fdiff_adjoint_neumann(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), gradZfull.data(), NULL, nx, ny, nz, 1);
596596
}
597597

598598
omp_set_num_threads(nThreads_initial);
599599
return status;
600600
}
601-
int fdiff2D(nb::ndarray<float> imagefull,
602-
nb::ndarray<float> gradYfull,
603-
nb::ndarray<float> gradXfull,
601+
int fdiff2D(DataFloat imagefull,
602+
DataFloat gradYfull,
603+
DataFloat gradXfull,
604604
size_t ny, size_t nx,
605605
int boundary, int direction,
606606
int nThreads)
@@ -612,16 +612,16 @@ int fdiff2D(nb::ndarray<float> imagefull,
612612
if (boundary)
613613
{
614614
if (direction)
615-
status = fdiff_direct_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
615+
status = fdiff_direct_periodic(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), NULL, NULL, nx, ny, 1, 1);
616616
else
617-
status = fdiff_adjoint_periodic(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
617+
status = fdiff_adjoint_periodic(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
618618
}
619619
else
620620
{
621621
if (direction)
622-
status = fdiff_direct_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
622+
status = fdiff_direct_neumann(imagefull.data(), gradXfull.mutable_data(), gradYfull.mutable_data(), NULL, NULL, nx, ny, 1, 1);
623623
else
624-
status = fdiff_adjoint_neumann(imagefull.data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
624+
status = fdiff_adjoint_neumann(imagefull.mutable_data(), gradXfull.data(), gradYfull.data(), NULL, NULL, nx, ny, 1, 1);
625625
}
626626

627627
omp_set_num_threads(nThreads_initial);

src/cil/axpby.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,13 @@ int saxpby(DataFloatInput x, DataFloatInput y,
121121
threads_setup(nThreads, &nThreads_initial);
122122

123123
if (type_a == 0 && type_b == 0)
124-
saxpby_asbs(x.data(), y.data(), out.data(), *a.data(), *b.data(), size, nThreads);
124+
saxpby_asbs(x.data(), y.data(), out.mutable_data(), *a.data(), *b.data(), size, nThreads);
125125
else if (type_a == 1 && type_b == 1)
126-
saxpby_avbv(x.data(), y.data(), out.data(), a.data(), b.data(), size, nThreads);
126+
saxpby_avbv(x.data(), y.data(), out.mutable_data(), a.data(), b.data(), size, nThreads);
127127
else if (type_a == 0 && type_b == 1)
128-
saxpby_asbv(x.data(), y.data(), out.data(), *a.data(), b.data(), size, nThreads);
128+
saxpby_asbv(x.data(), y.data(), out.mutable_data(), *a.data(), b.data(), size, nThreads);
129129
else if (type_a == 1 && type_b == 0)
130-
saxpby_asbv(y.data(), x.data(), out.data(), *b.data(), a.data(), size, nThreads);
130+
saxpby_asbv(y.data(), x.data(), out.mutable_data(), *b.data(), a.data(), size, nThreads);
131131

132132
omp_set_num_threads(nThreads_initial);
133133

@@ -149,13 +149,13 @@ int daxpby(DataDoubleInput x, DataDoubleInput y,
149149
threads_setup(nThreads, &nThreads_initial);
150150

151151
if (type_a == 0 && type_b == 0)
152-
daxpby_asbs(x.data(), y.data(), out.data(), *a.data(), *b.data(), size, nThreads);
152+
daxpby_asbs(x.data(), y.data(), out.mutable_data(), *a.data(), *b.data(), size, nThreads);
153153
else if (type_a == 1 && type_b == 1)
154-
daxpby_avbv(x.data(), y.data(), out.data(), a.data(), b.data(), size, nThreads);
154+
daxpby_avbv(x.data(), y.data(), out.mutable_data(), a.data(), b.data(), size, nThreads);
155155
else if (type_a == 0 && type_b == 1)
156-
daxpby_asbv(x.data(), y.data(), out.data(), *a.data(), b.data(), size, nThreads);
156+
daxpby_asbv(x.data(), y.data(), out.mutable_data(), *a.data(), b.data(), size, nThreads);
157157
else if (type_a == 1 && type_b == 0)
158-
daxpby_asbv(y.data(), x.data(), out.data(), *b.data(), a.data(), size, nThreads);
158+
daxpby_asbv(y.data(), x.data(), out.mutable_data(), *b.data(), a.data(), size, nThreads);
159159

160160
omp_set_num_threads(nThreads_initial);
161161

src/cil/include/Binning.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
// CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
1818

1919
#include <cstddef>
20-
#include <nanobind/nanobind.h>
21-
#include <nanobind/ndarray.h>
20+
#include <pybind11/pybind11.h>
21+
#include <pybind11/numpy.h>
2222

23-
namespace nb = nanobind;
23+
namespace py = pybind11;
2424

25-
using Shape = nb::ndarray<const size_t>;
26-
using DataInput = nb::ndarray<const float>;
27-
using DataBinned = nb::ndarray<float>;
25+
26+
using Shape = py::array_t<const size_t>;
27+
using DataInput = py::array_t<const float>;
28+
using DataBinned = py::array_t<float>;
2829

2930
void Binner_delete(void* binner);
3031
void* Binner_new(Shape shape_in,

src/cil/include/FBP_filtering.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
#include <omp.h>
2525
#include <random>
2626
#include "utilities.h"
27-
#include <nanobind/ndarray.h>
27+
#include <pybind11/pybind11.h>
28+
#include <pybind11/numpy.h>
2829

29-
namespace nb = nanobind;
30+
namespace py = pybind11;
3031

31-
using DataFloatConst = nb::ndarray<const float>;
32-
using DataFloat = nb::ndarray<float>;
32+
using DataFloatConst = py::array_t<const float>;
33+
using DataFloat = py::array_t<float>;
3334

3435
int filter_projections_avh(DataFloat data, DataFloatConst filter, DataFloatConst weights, int order, long num_proj, long pix_y, long pix_x);
3536
int filter_projections_vah(DataFloat data, DataFloatConst filter, DataFloatConst weights, int order, long pix_y, long num_proj, long pix_x);

src/cil/include/FiniteDifferenceLibrary.h

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,16 @@
2222
#include <omp.h>
2323
#include "utilities.h"
2424

25-
#include <nanobind/ndarray.h>
26-
namespace nb = nanobind;
25+
#include <pybind11/pybind11.h>
26+
#include <pybind11/numpy.h>
27+
28+
namespace py = pybind11;
29+
30+
using DataFloat = py::array_t<float>;
31+
32+
int fdiff_direct_neumann(const float *inimagefull, float *outimageXfull, float *outimageYfull, float *outimageZfull, float *outimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
33+
int fdiff_direct_periodic(const float *inimagefull, float *outimageXfull, float *outimageYfull, float *outimageZfull, float *outimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
34+
namespace py = pybind11;
2735

2836

2937
int fdiff_direct_neumann(const float *inimagefull, float *outimageXfull, float *outimageYfull, float *outimageZfull, float *outimageCfull, size_t nx, size_t ny, size_t nz, size_t nc);
@@ -33,24 +41,24 @@ int fdiff_adjoint_periodic(float *outimagefull, const float *inimageXfull, const
3341

3442

3543
int openMPtest(int nThreads);
36-
int fdiff4D(nb::ndarray<float> imagefull,
37-
nb::ndarray<float> gradCfull,
38-
nb::ndarray<float> gradZfull,
39-
nb::ndarray<float> gradYfull,
40-
nb::ndarray<float> gradXfull,
44+
int fdiff4D(DataFloat imagefull,
45+
DataFloat gradCfull,
46+
DataFloat gradZfull,
47+
DataFloat gradYfull,
48+
DataFloat gradXfull,
4149
size_t nc, size_t nz, size_t ny,
4250
size_t nx, int boundary, int direction,
4351
int nThreads);
44-
int fdiff3D(nb::ndarray<float> imagefull,
45-
nb::ndarray<float> gradZfull,
46-
nb::ndarray<float> gradYfull,
47-
nb::ndarray<float> gradXfull,
52+
int fdiff3D(DataFloat imagefull,
53+
DataFloat gradZfull,
54+
DataFloat gradYfull,
55+
DataFloat gradXfull,
4856
size_t nz, size_t ny, size_t nx,
4957
int boundary, int direction,
5058
int nThreads);
51-
int fdiff2D(nb::ndarray<float> imagefull,
52-
nb::ndarray<float> gradYfull,
53-
nb::ndarray<float> gradXfull,
59+
int fdiff2D(DataFloat imagefull,
60+
DataFloat gradYfull,
61+
DataFloat gradXfull,
5462
size_t ny, size_t nx,
5563
int boundary, int direction,
5664
int nThreads);

src/cil/include/axpby.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,19 @@
2121
#include <stdio.h>
2222
#include <omp.h>
2323
#include "utilities.h"
24-
#include <nanobind/ndarray.h>
24+
#include <pybind11/pybind11.h>
25+
#include <pybind11/numpy.h>
26+
27+
namespace py = pybind11;
2528

26-
namespace nb = nanobind;
2729

2830
using int64 = long long;
2931

30-
using DataFloatInput = nb::ndarray<const float>;
31-
using DataDoubleInput = nb::ndarray<const double>;
32+
using DataFloatInput = py::array_t<const float>;
33+
using DataDoubleInput = py::array_t<const double>;
3234

33-
using DataFloatOutput = nb::ndarray<float>;
34-
using DataDoubleOutput = nb::ndarray<double>;
35+
using DataFloatOutput = py::array_t<float>;
36+
using DataDoubleOutput = py::array_t<double>;
3537

3638

3739
int saxpby_asbs(const float * x, const float * y, float * out, float a, float b, int64 size, int nThreads);
@@ -42,7 +44,7 @@ int daxpby_avbv(const double * x, const double * y, double * out, const double *
4244
int daxpby_asbv(const double * x, const double * y, double * out, double a, const double * b, int64 size, int nThreads);
4345

4446
int saxpby(DataFloatInput x, DataFloatInput y,
45-
DataFloatOutput out,
47+
DataFloatOutput out,
4648
DataFloatInput a, int type_a,
4749
DataFloatInput b, int type_b,
4850
int64 size, int nThreads);

0 commit comments

Comments
 (0)