You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Expose functools.reduce initializer argument to tree_util.tree_reduce (#2935)
* Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`.
`functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter.
Example:
```
def l2_sum(total, param):
return total + jnp.sum(param**2)
tree_reduce(l2_sum, params, 0.)
```
* Only call functools.reduce with initializer when it is not None.
* Change logic to check for number of args to allow None value as initializer
* Rename seq to tree, and add tree_leaves
* Change reduce to functools.reduce.
* Make tree_reduce self-documenting
* Replace jax.tree_leaves with tree_leaves
* Update to use custom sentinel instead of optional position argument
* jax.tree_leaves -> tree_leaves
0 commit comments