Skip to content

Commit 10e44c1

Browse files
committed
fix: dynamic array 2d data access attribute again
1 parent d65aae1 commit 10e44c1

File tree

1 file changed

+80
-27
lines changed

1 file changed

+80
-27
lines changed

brian2/memory/cythondynamicarray.pyx

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,31 @@ from libc.string cimport memset
1111
from libc.stdint cimport int64_t, int32_t
1212
from cython cimport view
1313
from cpython.pycapsule cimport PyCapsule_New
14+
from cpython.ref cimport PyTypeObject
1415

1516
cnp.import_array()
1617

18+
cdef extern from "numpy/ndarrayobject.h":
19+
object PyArray_NewFromDescr(PyTypeObject* subtype,
20+
cnp.PyArray_Descr* descr,
21+
int nd,
22+
cnp.npy_intp* dims,
23+
cnp.npy_intp* strides,
24+
void* data,
25+
int flags,
26+
object obj)
27+
cnp.PyArray_Descr* PyArray_DescrFromType(int)
28+
29+
cdef extern from "numpy/ndarraytypes.h":
30+
void PyArray_CLEARFLAGS(cnp.PyArrayObject *arr, int flags)
31+
enum:
32+
NPY_ARRAY_C_CONTIGUOUS
33+
NPY_ARRAY_F_CONTIGUOUS
34+
NPY_ARRAY_OWNDATA
35+
NPY_ARRAY_WRITEABLE
36+
NPY_ARRAY_ALIGNED
37+
NPY_ARRAY_WRITEBACKIFCOPY
38+
NPY_ARRAY_UPDATEIFCOPY
1739

1840
cdef extern from "dynamic_array.h":
1941
cdef cppclass DynamicArray1DCpp "DynamicArray1D"[T]:
@@ -383,41 +405,72 @@ cdef class DynamicArray2DClass:
383405
(<DynamicArray2DCpp[int64_t]*>self.thisptr).shrink(new_rows, new_cols)
384406
elif self.dtype == np.bool_:
385407
(<DynamicArray2DCpp[char]*>self.thisptr).shrink(new_rows, new_cols)
408+
386409
@property
387410
def data(self):
388-
"""Return numpy array view with proper strides"""
389-
cdef cnp.npy_intp shape[2]
390-
cdef cnp.npy_intp flat_size
391-
cdef cnp.ndarray buffer_view
411+
"""
412+
The magic getter! This creates a zero-copy NumPy 'view' of our C++ data.
413+
It's not a copy; it's a direct window into the C++ memory, which is why it's so fast.
414+
Every time our code accesses `my_array.data`, this code runs to build that view on the fly.
415+
"""
416+
# First, what's the logical shape the user sees,we get it ...
392417
cdef size_t rows = self.get_rows()
393418
cdef size_t cols = self.get_cols()
394-
cdef size_t stride = self.get_stride()
419+
# Now, the two most important pieces for our zero-copy trick:
420+
# 1. The actual memory address where our data lives in C++.
395421
cdef void* data_ptr = self.get_data_ptr()
396-
cdef size_t i, start_idx, end_idx # Loop variables
422+
# 2. The *physical* width of a row in memory. This might be wider than `cols`
423+
# if we've over-allocated space to make future growth faster.
424+
cdef size_t stride = self.get_stride()
425+
# How many bytes does one element take up? (e.g., 8 for a float64)
426+
cdef size_t itemsize = self.dtype.itemsize
397427

398-
if rows == 0 or cols == 0:
399-
return np.empty((0, 0), dtype=self.dtype)
428+
# --- Now we create the "map" that tells NumPy how to navigate our C++ memory correctly ---
400429

430+
# These are C-style arrays to hold the shape and the "stride map".
431+
cdef cnp.npy_intp shape[2]
432+
cdef cnp.npy_intp strides[2]
433+
434+
# So the shape is easy as it's just the logical dimensions.
435+
shape[0] = rows
436+
shape[1] = cols
437+
438+
# Now, the stride map. This tells NumPy how many *bytes* to jump to move through the data.
439+
# To move to the next item in the same row (j -> j+1), just jump by one item's size.
440+
strides[1] = itemsize
441+
# To move to the *next row* (i -> i+1), we have to jump over a whole physical row in memory.
442+
strides[0] = stride * itemsize
443+
444+
# We also need to describe our data type (e.g., float64) to NumPy in its native C language.
445+
cdef cnp.PyArray_Descr* descr = PyArray_DescrFromType(self.numpy_type)
446+
447+
# Now we set the permissions and properties for our numpy view
448+
# Let's start with a crucial permission: making the array writeable!
449+
# Without this, NumPy would make it read-only, and `arr[i] = x` would fail.
450+
cdef int flags = cnp.NPY_ARRAY_WRITEABLE
451+
452+
# A little optimization: if the memory is perfectly packed (no extra space in rows),
453+
# we can tell NumPy it's "C-contiguous". This can speed up some operations.
454+
if stride == cols:
455+
flags |= cnp.NPY_ARRAY_C_CONTIGUOUS
456+
457+
# Here we call the master C-API function, we give it:
458+
# the memory pointer, the shape map, the stride map, the data type, and the permissions.
459+
cdef cnp.ndarray result = <cnp.ndarray>PyArray_NewFromDescr(
460+
<PyTypeObject*>np.ndarray,
461+
descr,
462+
2,
463+
shape,
464+
strides,
465+
data_ptr,
466+
flags, # Use our flags variable
467+
None
468+
)
401469

402-
if stride ==cols:
403-
# Easy Case : buffer width = what we what
404-
shape[0] = rows
405-
shape[1] = cols
406-
return cnp.PyArray_SimpleNewFromData(2, shape , self.numpy_type,data_ptr)
407-
else:
408-
# if stride != cols , we copy data instead of using strides
409-
# Tricky case : buffer is wider than what we want , so
410-
# We just copy the parts we need for the view
411-
result = np.empty((rows,cols),dtype=self.dtype)
412-
flat_size = rows * stride
413-
buffer_view = cnp.PyArray_SimpleNewFromData(1, &flat_size, self.numpy_type, data_ptr)
414-
# Copy each row from buffer to result
415-
for i in range(rows):
416-
start_idx = i * stride
417-
end_idx = start_idx + cols
418-
result[i, :] = buffer_view[start_idx:end_idx]
419-
420-
return result
470+
# By default, NumPy assumes it owns the data and will try to free it later.
471+
# But *our* C++ vector owns it! Clearing this flag prevents a double-free, which would crash the program.
472+
cnp.PyArray_CLEARFLAGS(result, cnp.NPY_ARRAY_OWNDATA)
473+
return result
421474

422475
@property
423476
def shape(self):

0 commit comments

Comments
 (0)