Skip to content

Commit 0415602

Browse files
pfackeldeyiannapre-commit-ci[bot]
authored
feat: add axis=None reducer specializations (#3653)
* feat: add axis=None reducer specializations * only add an axis=None specialization for sum * fix check for rectangular 1D arrays to determine if axis=None can be used * fix axis=0,1 to axis=None specialization path * style: pre-commit fixes * simplify if condition --------- Co-authored-by: Ianna Osborne <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b42b3d8 commit 0415602

File tree

6 files changed

+101
-0
lines changed

6 files changed

+101
-0
lines changed

src/awkward/_do.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,13 @@ def reduce(
224224
keepdims: bool = False,
225225
behavior: dict | None = None,
226226
):
227+
# store the original reducer for potential reuse later
228+
original_reducer = reducer
227229
reducer = layout.backend.prepare_reducer(reducer)
228230

229231
if axis is None:
232+
del original_reducer # not used below this point
233+
230234
parts = remove_structure(
231235
layout,
232236
flatten_records=False,
@@ -247,6 +251,16 @@ def reduce(
247251
else:
248252
(layout,) = parts
249253

254+
# Check if we're running with concrete data and if the reducer has an axis=None specialization.
255+
# If both are true, we use the specialized reducer. This allows us to use optimized implementations
256+
# from e.g. NumPy, but also make use of potentially better algorithms, e.g. Kahan summation for sum.
257+
if (
258+
layout.backend.nplike.known_data
259+
and (specialization := reducer.axis_none_reducer()) is not None
260+
):
261+
# overwrite reducer if it has an axis=None version
262+
reducer = specialization
263+
250264
starts = ak.index.Index64.zeros(1, layout.backend.nplike)
251265
parents = ak.index.Index64.zeros(layout.length, layout.backend.nplike)
252266
shifts = None
@@ -288,6 +302,19 @@ def reduce(
288302
f"(which is {depth})"
289303
)
290304

305+
# a flat array can be fully reduced with axis=None or axis=0 or axis=-1,
306+
# so we treat them as equivalent and recurse to the axis=None specialization
307+
if depth == negaxis == 1:
308+
return reduce(
309+
layout=layout,
310+
reducer=original_reducer,
311+
axis=None,
312+
mask=mask,
313+
keepdims=keepdims,
314+
behavior=behavior,
315+
)
316+
del original_reducer # not used below this point
317+
291318
starts = ak.index.Index64.zeros(1, layout.backend.nplike)
292319
parents = ak.index.Index64.zeros(layout.length, layout.backend.nplike)
293320
shifts = None

src/awkward/_nplikes/array_module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,17 @@ def min(
665665
(x,) = maybe_materialize(x)
666666
return self._module.min(x, axis=axis, keepdims=keepdims, out=maybe_out)
667667

668+
def sum(
    self,
    x: ArrayLikeT,
    *,
    axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
    keepdims: bool = False,
    maybe_out: ArrayLikeT | None = None,
) -> ArrayLikeT:
    """Sum ``x`` by delegating to the wrapped array module's ``sum``.

    ``axis`` and ``keepdims`` are forwarded unchanged, and ``maybe_out`` is
    handed to the module as its ``out`` argument.
    """
    # The input goes through maybe_materialize before it reaches the module.
    (data,) = maybe_materialize(x)
    module_sum = self._module.sum
    return module_sum(data, axis=axis, keepdims=keepdims, out=maybe_out)
678+
668679
def max(
669680
self,
670681
x: ArrayLikeT,

src/awkward/_nplikes/cupy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,21 @@ def min(
139139
else:
140140
return out
141141

142+
def sum(
    self,
    x: ArrayLike,
    *,
    axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
    keepdims: bool = False,
    maybe_out: ArrayLike | None = None,
) -> ArrayLike:
    """Sum ``x`` by delegating to the wrapped array module's ``sum``.

    ``axis`` and ``keepdims`` are forwarded unchanged, and ``maybe_out`` is
    handed to the module as its ``out`` argument. A full reduction
    (``axis=None`` without ``keepdims``) that yields a module ``ndarray`` is
    unwrapped with ``.item()``; every other result is returned as-is.
    """
    (x,) = maybe_materialize(x)
    # keepdims must be forwarded; otherwise the caller's request is silently
    # dropped and the result shape differs from what was asked for.
    out = self._module.sum(x, axis=axis, keepdims=keepdims, out=maybe_out)
    # Only unwrap when the caller did not ask to keep the reduced dimensions:
    # .item() on a kept-dims result would throw those dimensions away again.
    if axis is None and not keepdims and isinstance(out, self._module.ndarray):
        return out.item()
    return out
156+
142157
def max(
143158
self,
144159
x: ArrayLike,

src/awkward/_nplikes/numpy_like.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,16 @@ def max(
467467
maybe_out: ArrayLikeT | None = None,
468468
) -> ArrayLikeT: ...
469469

470+
@abstractmethod
def sum(
    self,
    x: ArrayLikeT,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
    maybe_out: ArrayLikeT | None = None,
) -> ArrayLikeT:
    """Abstract ``sum`` reduction; the body is intentionally empty.

    Declared abstract, so the implementation is supplied by an override that
    accepts the same keyword-only ``axis``, ``keepdims`` and ``maybe_out``.
    """
479+
470480
@abstractmethod
471481
def count_nonzero(
472482
self, x: ArrayLikeT, *, axis: int | tuple[int, ...] | None = None

src/awkward/_nplikes/typetracer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,16 @@ def max(
16461646
) -> TypeTracerArray:
16471647
return self.min(x, axis=axis, keepdims=keepdims, maybe_out=maybe_out)
16481648

1649+
def sum(
    self,
    x: TypeTracerArray,
    *,
    axis: int | tuple[int, ...] | None = None,
    keepdims: bool = False,
    maybe_out: TypeTracerArray | None = None,
) -> TypeTracerArray:
    """Return whatever ``self.min`` returns for the same arguments.

    No summation is performed here; the call is forwarded verbatim.
    """
    # NOTE(review): this relies on ``self.min`` producing the result wanted
    # for a sum as well -- confirm, since ``min`` is not visible from here.
    forwarded = {"axis": axis, "keepdims": keepdims, "maybe_out": maybe_out}
    return self.min(x, **forwarded)
1658+
16491659
def array_str(
16501660
self,
16511661
x: TypeTracerArray,

src/awkward/_reducers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def preferred_dtype(self) -> DTypeLike: ...
3333
def highlevel_function(cls):
3434
return getattr(ak.operations, cls.name)
3535

36+
@classmethod
def axis_none_reducer(cls) -> Reducer | None:
    """Return a reducer specialized for axis=None reductions, if one exists.

    This default implementation has no specialization to offer and therefore
    always returns ``None``.
    """
    specialization = None
    return specialization
40+
3641
@abstractmethod
3742
def apply(
3843
self,
@@ -336,6 +341,10 @@ class Sum(KernelReducer):
336341
preferred_dtype: Final = np.float64
337342
needs_position: Final = False
338343

344+
@classmethod
def axis_none_reducer(cls):
    """Return a new ``AxisNoneSum`` instance as the axis=None specialization."""
    specialized = AxisNoneSum()
    return specialized
347+
339348
def apply(
340349
self,
341350
array: ak.contents.NumpyArray,
@@ -435,6 +444,25 @@ def apply(
435444
)
436445

437446

447+
class AxisNoneSum(Sum):
    """``Sum`` variant that reduces a whole ``NumpyArray`` with the nplike's own reducer."""

    def apply(
        self,
        array: ak.contents.NumpyArray,
        parents: ak.index.Index,
        starts: ak.index.Index,
        shifts: ak.index.Index | None,
        outlength: ShapeItem,
    ) -> ak.contents.NumpyArray:
        """Reduce all of ``array.data`` into a one-element ``NumpyArray``.

        The reducer function is looked up on ``array.backend.nplike`` by
        ``self.name`` and called with ``axis=None``. ``parents``, ``starts``,
        ``shifts`` and ``outlength`` are accepted but not used.

        Raises:
            ValueError: if ``array.dtype.kind`` is ``"M"``.
        """
        # Everything is folded into a single value, so the grouping arguments
        # are dropped right away.
        del parents, starts, shifts, outlength
        assert isinstance(array, ak.contents.NumpyArray)

        dtype_kind = array.dtype.kind
        if dtype_kind == "M":
            raise ValueError(f"cannot compute the sum (ak.sum) of {array.dtype!r}")

        backend = array.backend
        nplike_reducer = getattr(backend.nplike, self.name)
        total = nplike_reducer(array.data, axis=None)
        # Wrap the single result in a list to build a length-1 array.
        return ak.contents.NumpyArray([total], backend=backend)
464+
465+
438466
class Prod(KernelReducer):
439467
name: Final = "prod"
440468
preferred_dtype: Final = np.float64

0 commit comments

Comments
 (0)