Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ These are the release notes for JAX.
### Breaking changes

* The minimum jaxlib version is now 0.1.38.
* Simplified `Jaxpr` by removing the `Jaxpr.freevars` and changing the
representation of `Jaxpr.bound_subjaxprs` to drop the environment values.

### New features

Expand Down
4 changes: 0 additions & 4 deletions docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@
"invars: [a]\n",
"outvars: [b]\n",
"constvars: []\n",
"freevars: []\n",
"equation: [a, 1] add [b] {}\n",
"\n",
"jaxpr: { lambda ; ; a.\n",
Expand All @@ -161,7 +160,6 @@
"invars: [a, b, c]\n",
"outvars: [g, c]\n",
"constvars: [f]\n",
"freevars: []\n",
"equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None}\n",
"equation: [d, b] add [e] {}\n",
"equation: [e, f] add [g] {}\n",
Expand All @@ -182,7 +180,6 @@
" print(\"invars:\", jaxpr.invars)\n",
" print(\"outvars:\", jaxpr.outvars)\n",
" print(\"constvars:\", jaxpr.constvars)\n",
" print(\"freevars:\", jaxpr.freevars)\n",
" for eqn in jaxpr.eqns:\n",
" print(\"equation:\", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)\n",
" print()\n",
Expand Down Expand Up @@ -213,7 +210,6 @@
"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions\n",
"* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.\n",
"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)\n",
"* `jaxpr.freevars` - these can arise when nesting `jit` and `pmap` transformations; we won't worry about them in this colab.\n",
"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
"\n",
"All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
Expand Down
7 changes: 3 additions & 4 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def computation_maker(*args, **kwargs):
xla_consts = map(c.Constant, consts)
xla_args = xla._xla_callable_args(c, avals, tuple_args)
outs = xla.jaxpr_subcomp(
c, jaxpr, backend, axis_env_, xla_consts, (),
c, jaxpr, backend, axis_env_, xla_consts,
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
return c.Build(c.Tuple(*outs))
return computation_maker
Expand Down Expand Up @@ -1211,7 +1211,7 @@ def fun(*tangents):
"the original primal values.")
raise ValueError(msg)
dummy = (core.unit,) * len(tangents)
out = eval_jaxpr(jaxpr, consts, (), *(dummy + tangents))
out = eval_jaxpr(jaxpr, consts, *(dummy + tangents))
tangents_out = out[len(out)//2:]
return tuple(map(pe.merge_pvals, tangents_out, out_pvals))

Expand Down Expand Up @@ -1491,7 +1491,7 @@ def custom_transforms(fun):

def fun_impl(*args, **params):
consts, args = split_list(args, [params['num_consts']])
return core.eval_jaxpr(params['jaxpr'], consts, (), *args)
return core.eval_jaxpr(params['jaxpr'], consts, *args)
fun_p.def_impl(fun_impl)

def fun_jvp(primals, tangents, **params):
Expand Down Expand Up @@ -1918,7 +1918,6 @@ def jaxpr_to_graphviz(jaxpr, consts):
fragment = []

fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
fragment.extend(map(freevar_node, jaxpr.freevars, jaxpr.freevars))
fragment.extend(map(constant_node, jaxpr.constvars, consts))

for eqn in jaxpr.eqns:
Expand Down
38 changes: 20 additions & 18 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,18 @@
# -------------------- jaxprs --------------------

class Jaxpr(object):
def __init__(self, constvars, freevars, invars, outvars, eqns):
def __init__(self, constvars, invars, outvars, eqns):
"""
Params:
constvars: list of variables introduced for constants (either literals
in the Python program, or the result of constant folding during the
generation of the Jaxpr). Array constants are replaced with such variables
while scalar constants are kept inline.
invars: list of input variables. Together, `constvars` and `invars` are
the inputs to the Jaxpr.
outvars: list of output variables.
eqns: list of equations."""
self.constvars = list(constvars)
self.freevars = list(freevars)
self.invars = list(invars)
self.outvars = list(outvars)
self.eqns = list(eqns)
Expand All @@ -56,7 +65,6 @@ def __init__(self, jaxpr, literals, in_avals, out_avals):
assert len(in_avals) == len(jaxpr.invars)
assert all(isinstance(aval, AbstractValue) for aval in in_avals)
assert all(isinstance(aval, AbstractValue) for aval in out_avals)
assert not jaxpr.freevars

self.jaxpr = jaxpr
self.literals = list(literals)
Expand All @@ -73,7 +81,7 @@ def __str__(self):

@curry
def jaxpr_as_fun(typed_jaxpr, *args):
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, (), *args)
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)


JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
Expand Down Expand Up @@ -186,7 +194,7 @@ def abstract_eval(self, *args, **kwargs):
# -------------------- lifting --------------------


def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
def eval_jaxpr(jaxpr, consts, *args):
def read(v):
if type(v) is Literal:
return v.val
Expand All @@ -200,12 +208,10 @@ def write(v, val):
write(unitvar, unit)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
map(write, jaxpr.freevars, freevar_vals)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings),
map(read, freevar_bindings))
for subjaxpr, const_bindings, freevar_bindings
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings))
for subjaxpr, const_bindings
in eqn.bound_subjaxprs]
subfuns = map(lu.wrap_init, subfuns)
ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
Expand Down Expand Up @@ -638,12 +644,10 @@ def write_env(env, v):

write(unitvar)
map(write, jaxpr.constvars)
map(write, jaxpr.freevars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
map(read, eqn.invars)
for subjaxpr, constvars, freevars in eqn.bound_subjaxprs:
map(read, freevars)
for subjaxpr, constvars in eqn.bound_subjaxprs:
map(read, constvars)
check_jaxpr(subjaxpr)
map(write, eqn.outvars)
Expand All @@ -662,19 +666,17 @@ def pp_eqn(eqn):
lhs = pp_vars(eqn.outvars)
pp_subexpr = pp('')
if eqn.bound_subjaxprs:
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs:
for subjaxpr, const_vars in eqn.bound_subjaxprs:
pp_subexpr = pp_subexpr + (
pp_jaxpr(subjaxpr).indent(2)
>> pp(' [ {} ; {} ]'.format(pp_vars(const_vars),
pp_vars(bound_vars))))
>> pp(' [ {} ]'.format(pp_vars(const_vars))))
return (pp('{} = '.format(lhs)) >>
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr

def pp_jaxpr(jaxpr):
return (pp('{{ lambda {} ; {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.freevars),
pp_vars(jaxpr.invars))) +
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.invars))) +
((pp('let ') >>
vcat(map(pp_eqn, jaxpr.eqns))) +
pp('in {} }}'.format(jaxpr.outvars))).indent(2))
66 changes: 34 additions & 32 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def vjp_(*cts):
cts = tuple(map(ignore_consts, cts, pvals))
dummy_primals_and_cts = (core.unit,) * len(cts) + cts
dummy_args = (undefined_primal,) * len(jaxpr.invars)
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts)
arg_cts = backward_pass(jaxpr, consts, dummy_args, dummy_primals_and_cts)
arg_cts = arg_cts[len(primals):]
return map(instantiate_zeros, primals, arg_cts)

Expand All @@ -137,9 +137,9 @@ def unpair_pval(pval):
aval_1, aval_2 = aval
return (aval_1, const_1), (aval_2, const_2)

def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
if all(ct is zero for ct in cotangents_in):
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
return [zero] * len(jaxpr.invars)

def write_cotangent(v, ct):
# assert v not in primal_env
Expand All @@ -162,7 +162,6 @@ def write_primal(v, val):
primal_env = {}
write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)

def is_linear(var):
Expand All @@ -184,22 +183,22 @@ def is_linear(var):
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
(subjaxpr, const_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
if any(is_linear(v) for v in eqn.invars):
linear_eqns.append(eqn)
elif eqn.primitive is not pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

# we special-case remat_call here because it can be mixed linear /
# nonlinear, so we always evaluate it even if it has a linear part
if eqn.primitive is pe.remat_call_p:
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)

ct_env = {}
Expand All @@ -211,29 +210,26 @@ def is_linear(var):
else:
cts_in, = map(read_cotangent, eqn.outvars)
if eqn.bound_subjaxprs:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
(subjaxpr, const_vars), = eqn.bound_subjaxprs
sub_consts = map(read_primal, const_vars)
sub_freevar_vals = map(read_primal, bound_vars)
ct_free_vars_out, cts_out = get_primitive_transpose(eqn.primitive)(
eqn.params, subjaxpr, sub_consts, sub_freevar_vals, invals, cts_in)
map(write_cotangent, bound_vars, ct_free_vars_out)
cts_out = get_primitive_transpose(eqn.primitive)(
eqn.params, subjaxpr, sub_consts, invals, cts_in)
else:
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
map(write_cotangent, eqn.invars, cts_out)

freevar_cts = map(read_cotangent, jaxpr.freevars)
cotangents_out = map(read_cotangent, jaxpr.invars)
return freevar_cts, cotangents_out
return cotangents_out

def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
def _eval_subjaxpr_primals(prim, jaxpr, consts, in_vals, params):
all_args, in_tree_def = tree_flatten((consts, in_vals))
fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
out_flat = prim.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)

def _eval_primals(jaxpr, consts, freevar_vals, args):
def _eval_primals(jaxpr, consts, args):
primal_env = {}

def read_primal(v):
Expand All @@ -254,7 +250,6 @@ def is_linear(var):

write_primal(core.unitvar, core.unit)
map(write_primal, jaxpr.constvars, consts)
map(write_primal, jaxpr.freevars, freevar_vals)
map(write_primal, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if not eqn.bound_subjaxprs:
Expand All @@ -266,13 +261,13 @@ def is_linear(var):
else:
write_primal(eqn.outvars[0], ans)
else:
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
(subjaxpr, const_vars), = eqn.bound_subjaxprs
assert not any(is_linear(v) for v in const_vars)
if (eqn.primitive is pe.remat_call_p or
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
not any(is_linear(v) for v in eqn.invars)):
ans = _eval_subjaxpr_primals(
eqn.primitive, subjaxpr, map(read_primal, const_vars),
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
map(read_primal, eqn.invars), eqn.params)
map(write_primal, eqn.outvars, ans)
return map(read_primal, jaxpr.outvars)

Expand Down Expand Up @@ -471,7 +466,7 @@ def fun_lin_transpose(cts, *args, **kwargs):
num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
res, _ = split_list(args, [num_res])
cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
outs = core.eval_jaxpr(trans_jaxpr, res, (), *cts)
outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
return [None] * num_res + outs
primitive_transposes[fun_lin_p] = fun_lin_transpose

Expand Down Expand Up @@ -544,8 +539,8 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
yield out_flat, tree_def


def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
def call_transpose(primitive, params, jaxpr, consts, args, ct):
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
Expand All @@ -554,15 +549,22 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)

def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
def map_transpose(primitive, params, jaxpr, consts, args, ct):
all_args, in_tree_def = tree_flatten((consts, args, ct))
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
params = dict(params, name=wrap_name(params['name'], 'transpose'))
out_flat = primitive.bind(fun, *all_args, **params)
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)
freevar_cts = [x.sum(0) if x is not zero else x for x in freevar_cts]
return freevar_cts, arg_cts
arg_cts = tree_unflatten(out_tree(), out_flat)

mapped_invars = params['mapped_invars'] # True for each mapped invar
# The freevars are being fanned out (not mapped). During transpose the
# dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
assert len(mapped_invars) == len(arg_cts)
arg_cts = (arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
for arg_ct, arg_mapped in zip(arg_cts, mapped_invars))

return arg_cts


def jvp_jaxpr(jaxpr, nonzeros, instantiate):
Expand All @@ -588,10 +590,10 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents):
nonzero_tangents_out = [t for t in tangents_out if t is not zero]
yield list(primals_out) + nonzero_tangents_out, out_nonzeros

def rearrange_binders(jaxpr, primals_in, tangents_in, primals_out, tangents_out):
def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, jaxpr.jaxpr.freevars,
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns)
new_in_avals = _perm(primals_in, tangents_in, jaxpr.in_avals)
new_out_avals = _perm(primals_out, tangents_out, jaxpr.out_avals)
Expand Down
Loading