Skip to content

Commit dc234b6

Browse files
authored
Expose functools.reduce initializer argument to tree_util.tree_reduce (#2935)
* Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`. `functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter. Example: ``` def l2_sum(total, param): return total + jnp.sum(param**2) tree_reduce(l2_sum, params, 0.) ``` * Only call functools.reduce with initializer when it is not None. * Change logic to check for number of args to allow None value as initializer * Rename seq to tree, and add tree_leaves * Change reduce to functools.reduce. * Make tree_reduce self-documenting * Replace jax.tree_leaves with tree_leaves * Update to use custom sentinel instead of optional position argument * jax.tree_leaves -> tree_leaves
1 parent e4d8cac commit dc234b6

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

jax/tree_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,12 @@ def _replace_nones(sentinel, tree):
231231
else:
232232
return tree
233233

234-
def tree_reduce(f, tree):
235-
return functools.reduce(f, tree_leaves(tree))
234+
no_initializer = object()
235+
def tree_reduce(function, tree, initializer=no_initializer):
236+
if initializer is no_initializer:
237+
return functools.reduce(function, tree_leaves(tree))
238+
else:
239+
return functools.reduce(function, tree_leaves(tree), initializer)
236240

237241
def tree_all(tree):
238242
return all(tree_leaves(tree))

0 commit comments

Comments
 (0)