Skip to content

Commit d01210e

Browse files
authored
Merge pull request #1959 from gnecula/no_freevars
An attempt to remove freevars from JAXPR.
2 parents ddc83e0 + 862a1d5 commit d01210e

File tree

12 files changed

+186
-140
lines changed

12 files changed

+186
-140
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ These are the release notes for JAX.
77
### Breaking changes
88

99
* The minimum jaxlib version is now 0.1.38.
10+
* Simplified `Jaxpr` by removing `Jaxpr.freevars` and changing the
11+
representation of `JaxprEqn.bound_subjaxprs` to drop the environment values.
1012

1113
### New features
1214

docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@
148148
"invars: [a]\n",
149149
"outvars: [b]\n",
150150
"constvars: []\n",
151-
"freevars: []\n",
152151
"equation: [a, 1] add [b] {}\n",
153152
"\n",
154153
"jaxpr: { lambda ; ; a.\n",
@@ -161,7 +160,6 @@
161160
"invars: [a, b, c]\n",
162161
"outvars: [g, c]\n",
163162
"constvars: [f]\n",
164-
"freevars: []\n",
165163
"equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None}\n",
166164
"equation: [d, b] add [e] {}\n",
167165
"equation: [e, f] add [g] {}\n",
@@ -182,7 +180,6 @@
182180
" print(\"invars:\", jaxpr.invars)\n",
183181
" print(\"outvars:\", jaxpr.outvars)\n",
184182
" print(\"constvars:\", jaxpr.constvars)\n",
185-
" print(\"freevars:\", jaxpr.freevars)\n",
186183
" for eqn in jaxpr.eqns:\n",
187184
" print(\"equation:\", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)\n",
188185
" print()\n",
@@ -213,7 +210,6 @@
213210
"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions\n",
214211
"* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.\n",
215212
"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)\n",
216-
"* `jaxpr.freevars` - these can arise when nesting `jit` and `pmap` transformations; we won't worry about them in this colab.\n",
217213
"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation consists of a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
218214
"\n",
219215
"All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."

jax/api.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def computation_maker(*args, **kwargs):
305305
xla_consts = map(c.Constant, consts)
306306
xla_args = xla._xla_callable_args(c, avals, tuple_args)
307307
outs = xla.jaxpr_subcomp(
308-
c, jaxpr, backend, axis_env_, xla_consts, (),
308+
c, jaxpr, backend, axis_env_, xla_consts,
309309
extend_name_stack(wrap_name(fun_name, 'xla_computation')), *xla_args)
310310
return c.Build(c.Tuple(*outs))
311311
return computation_maker
@@ -1211,7 +1211,7 @@ def fun(*tangents):
12111211
"the original primal values.")
12121212
raise ValueError(msg)
12131213
dummy = (core.unit,) * len(tangents)
1214-
out = eval_jaxpr(jaxpr, consts, (), *(dummy + tangents))
1214+
out = eval_jaxpr(jaxpr, consts, *(dummy + tangents))
12151215
tangents_out = out[len(out)//2:]
12161216
return tuple(map(pe.merge_pvals, tangents_out, out_pvals))
12171217

@@ -1491,7 +1491,7 @@ def custom_transforms(fun):
14911491

14921492
def fun_impl(*args, **params):
14931493
consts, args = split_list(args, [params['num_consts']])
1494-
return core.eval_jaxpr(params['jaxpr'], consts, (), *args)
1494+
return core.eval_jaxpr(params['jaxpr'], consts, *args)
14951495
fun_p.def_impl(fun_impl)
14961496

14971497
def fun_jvp(primals, tangents, **params):
@@ -1918,7 +1918,6 @@ def jaxpr_to_graphviz(jaxpr, consts):
19181918
fragment = []
19191919

19201920
fragment.extend(map(invar_node, jaxpr.invars, jaxpr.invars))
1921-
fragment.extend(map(freevar_node, jaxpr.freevars, jaxpr.freevars))
19221921
fragment.extend(map(constant_node, jaxpr.constvars, consts))
19231922

19241923
for eqn in jaxpr.eqns:

jax/core.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,18 @@
3838
# -------------------- jaxprs --------------------
3939

4040
class Jaxpr(object):
41-
def __init__(self, constvars, freevars, invars, outvars, eqns):
41+
def __init__(self, constvars, invars, outvars, eqns):
42+
"""
43+
Params:
44+
constvars: list of variables introduced for constants (either literals
45+
in the Python program, or the result of constant folding during the
46+
generation of the Jaxpr). Array constants are replaced with such variables
47+
while scalar constants are kept inline.
48+
invars: list of input variables. Together, `constvars` and `invars` are
49+
the inputs to the Jaxpr.
50+
outvars: list of output variables.
51+
eqns: list of equations."""
4252
self.constvars = list(constvars)
43-
self.freevars = list(freevars)
4453
self.invars = list(invars)
4554
self.outvars = list(outvars)
4655
self.eqns = list(eqns)
@@ -56,7 +65,6 @@ def __init__(self, jaxpr, literals, in_avals, out_avals):
5665
assert len(in_avals) == len(jaxpr.invars)
5766
assert all(isinstance(aval, AbstractValue) for aval in in_avals)
5867
assert all(isinstance(aval, AbstractValue) for aval in out_avals)
59-
assert not jaxpr.freevars
6068

6169
self.jaxpr = jaxpr
6270
self.literals = list(literals)
@@ -73,7 +81,7 @@ def __str__(self):
7381

7482
@curry
7583
def jaxpr_as_fun(typed_jaxpr, *args):
76-
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, (), *args)
84+
return eval_jaxpr(typed_jaxpr.jaxpr, typed_jaxpr.literals, *args)
7785

7886

7987
JaxprEqn = namedtuple('JaxprEqn', ['invars', 'outvars', 'primitive',
@@ -186,7 +194,7 @@ def abstract_eval(self, *args, **kwargs):
186194
# -------------------- lifting --------------------
187195

188196

189-
def eval_jaxpr(jaxpr, consts, freevar_vals, *args):
197+
def eval_jaxpr(jaxpr, consts, *args):
190198
def read(v):
191199
if type(v) is Literal:
192200
return v.val
@@ -200,12 +208,10 @@ def write(v, val):
200208
write(unitvar, unit)
201209
map(write, jaxpr.constvars, consts)
202210
map(write, jaxpr.invars, args)
203-
map(write, jaxpr.freevars, freevar_vals)
204211
for eqn in jaxpr.eqns:
205212
in_vals = map(read, eqn.invars)
206-
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings),
207-
map(read, freevar_bindings))
208-
for subjaxpr, const_bindings, freevar_bindings
213+
subfuns = [partial(eval_jaxpr, subjaxpr, map(read, const_bindings))
214+
for subjaxpr, const_bindings
209215
in eqn.bound_subjaxprs]
210216
subfuns = map(lu.wrap_init, subfuns)
211217
ans = eqn.primitive.bind(*(subfuns + in_vals), **eqn.params)
@@ -638,12 +644,10 @@ def write_env(env, v):
638644

639645
write(unitvar)
640646
map(write, jaxpr.constvars)
641-
map(write, jaxpr.freevars)
642647
map(write, jaxpr.invars)
643648
for eqn in jaxpr.eqns:
644649
map(read, eqn.invars)
645-
for subjaxpr, constvars, freevars in eqn.bound_subjaxprs:
646-
map(read, freevars)
650+
for subjaxpr, constvars in eqn.bound_subjaxprs:
647651
map(read, constvars)
648652
check_jaxpr(subjaxpr)
649653
map(write, eqn.outvars)
@@ -662,19 +666,17 @@ def pp_eqn(eqn):
662666
lhs = pp_vars(eqn.outvars)
663667
pp_subexpr = pp('')
664668
if eqn.bound_subjaxprs:
665-
for subjaxpr, const_vars, bound_vars in eqn.bound_subjaxprs:
669+
for subjaxpr, const_vars in eqn.bound_subjaxprs:
666670
pp_subexpr = pp_subexpr + (
667671
pp_jaxpr(subjaxpr).indent(2)
668-
>> pp(' [ {} ; {} ]'.format(pp_vars(const_vars),
669-
pp_vars(bound_vars))))
672+
>> pp(' [ {} ]'.format(pp_vars(const_vars))))
670673
return (pp('{} = '.format(lhs)) >>
671674
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
672675
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr
673676

674677
def pp_jaxpr(jaxpr):
675-
return (pp('{{ lambda {} ; {} ; {}.'.format(pp_vars(jaxpr.constvars),
676-
pp_vars(jaxpr.freevars),
677-
pp_vars(jaxpr.invars))) +
678+
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
679+
pp_vars(jaxpr.invars))) +
678680
((pp('let ') >>
679681
vcat(map(pp_eqn, jaxpr.eqns))) +
680682
pp('in {} }}'.format(jaxpr.outvars))).indent(2))

jax/interpreters/ad.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def vjp_(*cts):
110110
cts = tuple(map(ignore_consts, cts, pvals))
111111
dummy_primals_and_cts = (core.unit,) * len(cts) + cts
112112
dummy_args = (undefined_primal,) * len(jaxpr.invars)
113-
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primals_and_cts)
113+
arg_cts = backward_pass(jaxpr, consts, dummy_args, dummy_primals_and_cts)
114114
arg_cts = arg_cts[len(primals):]
115115
return map(instantiate_zeros, primals, arg_cts)
116116

@@ -137,9 +137,9 @@ def unpair_pval(pval):
137137
aval_1, aval_2 = aval
138138
return (aval_1, const_1), (aval_2, const_2)
139139

140-
def backward_pass(jaxpr, consts, freevar_vals, args, cotangents_in):
140+
def backward_pass(jaxpr: core.Jaxpr, consts, args, cotangents_in):
141141
if all(ct is zero for ct in cotangents_in):
142-
return [zero] * len(jaxpr.freevars), [zero] * len(jaxpr.invars)
142+
return [zero] * len(jaxpr.invars)
143143

144144
def write_cotangent(v, ct):
145145
# assert v not in primal_env
@@ -162,7 +162,6 @@ def write_primal(v, val):
162162
primal_env = {}
163163
write_primal(core.unitvar, core.unit)
164164
map(write_primal, jaxpr.constvars, consts)
165-
map(write_primal, jaxpr.freevars, freevar_vals)
166165
map(write_primal, jaxpr.invars, args)
167166

168167
def is_linear(var):
@@ -184,22 +183,22 @@ def is_linear(var):
184183
else:
185184
write_primal(eqn.outvars[0], ans)
186185
else:
187-
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
186+
(subjaxpr, const_vars), = eqn.bound_subjaxprs
188187
assert not any(is_linear(v) for v in const_vars)
189-
if any(is_linear(v) for v in it.chain(eqn.invars, bound_vars)):
188+
if any(is_linear(v) for v in eqn.invars):
190189
linear_eqns.append(eqn)
191190
elif eqn.primitive is not pe.remat_call_p:
192191
ans = _eval_subjaxpr_primals(
193192
eqn.primitive, subjaxpr, map(read_primal, const_vars),
194-
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
193+
map(read_primal, eqn.invars), eqn.params)
195194
map(write_primal, eqn.outvars, ans)
196195

197196
# we special-case remat_call here because it can be mixed linear /
198197
# nonlinear, so we always evaluate it even if it has a linear part
199198
if eqn.primitive is pe.remat_call_p:
200199
ans = _eval_subjaxpr_primals(
201200
eqn.primitive, subjaxpr, map(read_primal, const_vars),
202-
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
201+
map(read_primal, eqn.invars), eqn.params)
203202
map(write_primal, eqn.outvars, ans)
204203

205204
ct_env = {}
@@ -211,29 +210,26 @@ def is_linear(var):
211210
else:
212211
cts_in, = map(read_cotangent, eqn.outvars)
213212
if eqn.bound_subjaxprs:
214-
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
213+
(subjaxpr, const_vars), = eqn.bound_subjaxprs
215214
sub_consts = map(read_primal, const_vars)
216-
sub_freevar_vals = map(read_primal, bound_vars)
217-
ct_free_vars_out, cts_out = get_primitive_transpose(eqn.primitive)(
218-
eqn.params, subjaxpr, sub_consts, sub_freevar_vals, invals, cts_in)
219-
map(write_cotangent, bound_vars, ct_free_vars_out)
215+
cts_out = get_primitive_transpose(eqn.primitive)(
216+
eqn.params, subjaxpr, sub_consts, invals, cts_in)
220217
else:
221218
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals, **eqn.params)
222219
cts_out = [zero] * len(eqn.invars) if cts_out is zero else cts_out
223220
map(write_cotangent, eqn.invars, cts_out)
224221

225-
freevar_cts = map(read_cotangent, jaxpr.freevars)
226222
cotangents_out = map(read_cotangent, jaxpr.invars)
227-
return freevar_cts, cotangents_out
223+
return cotangents_out
228224

229-
def _eval_subjaxpr_primals(prim, jaxpr, consts, freevar_vals, in_vals, params):
230-
all_args, in_tree_def = tree_flatten((consts, freevar_vals, in_vals))
225+
def _eval_subjaxpr_primals(prim, jaxpr, consts, in_vals, params):
226+
all_args, in_tree_def = tree_flatten((consts, in_vals))
231227
fun = lu.hashable_partial(lu.wrap_init(_eval_primals), jaxpr)
232228
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
233229
out_flat = prim.bind(fun, *all_args, **params)
234230
return tree_unflatten(out_tree(), out_flat)
235231

236-
def _eval_primals(jaxpr, consts, freevar_vals, args):
232+
def _eval_primals(jaxpr, consts, args):
237233
primal_env = {}
238234

239235
def read_primal(v):
@@ -254,7 +250,6 @@ def is_linear(var):
254250

255251
write_primal(core.unitvar, core.unit)
256252
map(write_primal, jaxpr.constvars, consts)
257-
map(write_primal, jaxpr.freevars, freevar_vals)
258253
map(write_primal, jaxpr.invars, args)
259254
for eqn in jaxpr.eqns:
260255
if not eqn.bound_subjaxprs:
@@ -266,13 +261,13 @@ def is_linear(var):
266261
else:
267262
write_primal(eqn.outvars[0], ans)
268263
else:
269-
(subjaxpr, const_vars, bound_vars), = eqn.bound_subjaxprs
264+
(subjaxpr, const_vars), = eqn.bound_subjaxprs
270265
assert not any(is_linear(v) for v in const_vars)
271266
if (eqn.primitive is pe.remat_call_p or
272-
not any(is_linear(v) for v in it.chain(eqn.invars, bound_vars))):
267+
not any(is_linear(v) for v in eqn.invars)):
273268
ans = _eval_subjaxpr_primals(
274269
eqn.primitive, subjaxpr, map(read_primal, const_vars),
275-
map(read_primal, bound_vars), map(read_primal, eqn.invars), eqn.params)
270+
map(read_primal, eqn.invars), eqn.params)
276271
map(write_primal, eqn.outvars, ans)
277272
return map(read_primal, jaxpr.outvars)
278273

@@ -471,7 +466,7 @@ def fun_lin_transpose(cts, *args, **kwargs):
471466
num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
472467
res, _ = split_list(args, [num_res])
473468
cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
474-
outs = core.eval_jaxpr(trans_jaxpr, res, (), *cts)
469+
outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
475470
return [None] * num_res + outs
476471
primitive_transposes[fun_lin_p] = fun_lin_transpose
477472

@@ -544,8 +539,8 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
544539
yield out_flat, tree_def
545540

546541

547-
def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
548-
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
542+
def call_transpose(primitive, params, jaxpr, consts, args, ct):
543+
all_args, in_tree_def = tree_flatten((consts, args, ct))
549544
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
550545
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
551546
params = dict(params, name=wrap_name(params['name'], 'transpose'))
@@ -554,15 +549,22 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
554549
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
555550
primitive_transposes[pe.remat_call_p] = partial(call_transpose, pe.remat_call_p)
556551

557-
def map_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
558-
all_args, in_tree_def = tree_flatten((consts, freevar_vals, args, ct))
552+
def map_transpose(primitive, params, jaxpr, consts, args, ct):
553+
all_args, in_tree_def = tree_flatten((consts, args, ct))
559554
fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
560555
fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
561556
params = dict(params, name=wrap_name(params['name'], 'transpose'))
562557
out_flat = primitive.bind(fun, *all_args, **params)
563-
freevar_cts, arg_cts = tree_unflatten(out_tree(), out_flat)
564-
freevar_cts = [x.sum(0) if x is not zero else x for x in freevar_cts]
565-
return freevar_cts, arg_cts
558+
arg_cts = tree_unflatten(out_tree(), out_flat)
559+
560+
mapped_invars = params['mapped_invars'] # True for each mapped invar
561+
# Unmapped invars (the former freevars) are fanned out, not mapped. During transpose the
562+
# dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
563+
assert len(mapped_invars) == len(arg_cts)
564+
arg_cts = (arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
565+
for arg_ct, arg_mapped in zip(arg_cts, mapped_invars))
566+
567+
return arg_cts
566568

567569

568570
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
@@ -588,10 +590,10 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents):
588590
nonzero_tangents_out = [t for t in tangents_out if t is not zero]
589591
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
590592

591-
def rearrange_binders(jaxpr, primals_in, tangents_in, primals_out, tangents_out):
593+
def rearrange_binders(jaxpr: core.TypedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
592594
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
593595
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
594-
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, jaxpr.jaxpr.freevars,
596+
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
595597
new_invars, new_outvars, jaxpr.jaxpr.eqns)
596598
new_in_avals = _perm(primals_in, tangents_in, jaxpr.in_avals)
597599
new_out_avals = _perm(primals_out, tangents_out, jaxpr.out_avals)

0 commit comments

Comments
 (0)