@@ -110,7 +110,7 @@ def vjp_(*cts):
110
110
cts = tuple (map (ignore_consts , cts , pvals ))
111
111
dummy_primals_and_cts = (core .unit ,) * len (cts ) + cts
112
112
dummy_args = (undefined_primal ,) * len (jaxpr .invars )
113
- _ , arg_cts = backward_pass (jaxpr , consts , () , dummy_args , dummy_primals_and_cts )
113
+ arg_cts = backward_pass (jaxpr , consts , dummy_args , dummy_primals_and_cts )
114
114
arg_cts = arg_cts [len (primals ):]
115
115
return map (instantiate_zeros , primals , arg_cts )
116
116
@@ -137,9 +137,9 @@ def unpair_pval(pval):
137
137
aval_1 , aval_2 = aval
138
138
return (aval_1 , const_1 ), (aval_2 , const_2 )
139
139
140
- def backward_pass (jaxpr , consts , freevar_vals , args , cotangents_in ):
140
+ def backward_pass (jaxpr : core . Jaxpr , consts , args , cotangents_in ):
141
141
if all (ct is zero for ct in cotangents_in ):
142
- return [zero ] * len (jaxpr .freevars ), [ zero ] * len ( jaxpr . invars )
142
+ return [zero ] * len (jaxpr .invars )
143
143
144
144
def write_cotangent (v , ct ):
145
145
# assert v not in primal_env
@@ -162,7 +162,6 @@ def write_primal(v, val):
162
162
primal_env = {}
163
163
write_primal (core .unitvar , core .unit )
164
164
map (write_primal , jaxpr .constvars , consts )
165
- map (write_primal , jaxpr .freevars , freevar_vals )
166
165
map (write_primal , jaxpr .invars , args )
167
166
168
167
def is_linear (var ):
@@ -184,22 +183,22 @@ def is_linear(var):
184
183
else :
185
184
write_primal (eqn .outvars [0 ], ans )
186
185
else :
187
- (subjaxpr , const_vars , bound_vars ), = eqn .bound_subjaxprs
186
+ (subjaxpr , const_vars ), = eqn .bound_subjaxprs
188
187
assert not any (is_linear (v ) for v in const_vars )
189
- if any (is_linear (v ) for v in it . chain ( eqn .invars , bound_vars ) ):
188
+ if any (is_linear (v ) for v in eqn .invars ):
190
189
linear_eqns .append (eqn )
191
190
elif eqn .primitive is not pe .remat_call_p :
192
191
ans = _eval_subjaxpr_primals (
193
192
eqn .primitive , subjaxpr , map (read_primal , const_vars ),
194
- map (read_primal , bound_vars ), map ( read_primal , eqn .invars ), eqn .params )
193
+ map (read_primal , eqn .invars ), eqn .params )
195
194
map (write_primal , eqn .outvars , ans )
196
195
197
196
# we special-case remat_call here because it can be mixed linear /
198
197
# nonlinear, so we always evaluate it even if it has a linear part
199
198
if eqn .primitive is pe .remat_call_p :
200
199
ans = _eval_subjaxpr_primals (
201
200
eqn .primitive , subjaxpr , map (read_primal , const_vars ),
202
- map (read_primal , bound_vars ), map ( read_primal , eqn .invars ), eqn .params )
201
+ map (read_primal , eqn .invars ), eqn .params )
203
202
map (write_primal , eqn .outvars , ans )
204
203
205
204
ct_env = {}
@@ -211,29 +210,26 @@ def is_linear(var):
211
210
else :
212
211
cts_in , = map (read_cotangent , eqn .outvars )
213
212
if eqn .bound_subjaxprs :
214
- (subjaxpr , const_vars , bound_vars ), = eqn .bound_subjaxprs
213
+ (subjaxpr , const_vars ), = eqn .bound_subjaxprs
215
214
sub_consts = map (read_primal , const_vars )
216
- sub_freevar_vals = map (read_primal , bound_vars )
217
- ct_free_vars_out , cts_out = get_primitive_transpose (eqn .primitive )(
218
- eqn .params , subjaxpr , sub_consts , sub_freevar_vals , invals , cts_in )
219
- map (write_cotangent , bound_vars , ct_free_vars_out )
215
+ cts_out = get_primitive_transpose (eqn .primitive )(
216
+ eqn .params , subjaxpr , sub_consts , invals , cts_in )
220
217
else :
221
218
cts_out = get_primitive_transpose (eqn .primitive )(cts_in , * invals , ** eqn .params )
222
219
cts_out = [zero ] * len (eqn .invars ) if cts_out is zero else cts_out
223
220
map (write_cotangent , eqn .invars , cts_out )
224
221
225
- freevar_cts = map (read_cotangent , jaxpr .freevars )
226
222
cotangents_out = map (read_cotangent , jaxpr .invars )
227
- return freevar_cts , cotangents_out
223
+ return cotangents_out
228
224
229
def _eval_subjaxpr_primals(prim, jaxpr, consts, in_vals, params):
  """Run `_eval_primals` over `jaxpr` underneath the call primitive `prim`.

  `consts` and `in_vals` are flattened together into a single positional
  argument list, `_eval_primals` (partially applied to `jaxpr`) is wrapped so
  that it accepts those flat arguments, and the wrapped function is bound with
  `prim`, forwarding `params` as keyword arguments.

  Returns:
    The flat outputs of the bind, rebuilt with the output tree structure
    recorded by the flattening wrapper.
  """
  flat_args, args_treedef = tree_flatten((consts, in_vals))

  # Build the traceable function step by step: wrap, close over the jaxpr,
  # then adapt it to the flat calling convention.
  primal_fun = lu.wrap_init(_eval_primals)
  primal_fun = lu.hashable_partial(primal_fun, jaxpr)
  primal_fun, out_treedef_thunk = flatten_fun_nokwargs(primal_fun, args_treedef)

  flat_outs = prim.bind(primal_fun, *flat_args, **params)
  # The output tree is only available once the bind has run.
  return tree_unflatten(out_treedef_thunk(), flat_outs)
235
231
236
- def _eval_primals (jaxpr , consts , freevar_vals , args ):
232
+ def _eval_primals (jaxpr , consts , args ):
237
233
primal_env = {}
238
234
239
235
def read_primal (v ):
@@ -254,7 +250,6 @@ def is_linear(var):
254
250
255
251
write_primal (core .unitvar , core .unit )
256
252
map (write_primal , jaxpr .constvars , consts )
257
- map (write_primal , jaxpr .freevars , freevar_vals )
258
253
map (write_primal , jaxpr .invars , args )
259
254
for eqn in jaxpr .eqns :
260
255
if not eqn .bound_subjaxprs :
@@ -266,13 +261,13 @@ def is_linear(var):
266
261
else :
267
262
write_primal (eqn .outvars [0 ], ans )
268
263
else :
269
- (subjaxpr , const_vars , bound_vars ), = eqn .bound_subjaxprs
264
+ (subjaxpr , const_vars ), = eqn .bound_subjaxprs
270
265
assert not any (is_linear (v ) for v in const_vars )
271
266
if (eqn .primitive is pe .remat_call_p or
272
- not any (is_linear (v ) for v in it . chain ( eqn .invars , bound_vars ) )):
267
+ not any (is_linear (v ) for v in eqn .invars )):
273
268
ans = _eval_subjaxpr_primals (
274
269
eqn .primitive , subjaxpr , map (read_primal , const_vars ),
275
- map (read_primal , bound_vars ), map ( read_primal , eqn .invars ), eqn .params )
270
+ map (read_primal , eqn .invars ), eqn .params )
276
271
map (write_primal , eqn .outvars , ans )
277
272
return map (read_primal , jaxpr .outvars )
278
273
@@ -471,7 +466,7 @@ def fun_lin_transpose(cts, *args, **kwargs):
471
466
num_res , trans_jaxpr = kwargs ['num_res' ], kwargs ['trans_jaxpr' ]
472
467
res , _ = split_list (args , [num_res ])
473
468
cts = map (instantiate_zeros_aval , kwargs ['out_avals' ], cts )
474
- outs = core .eval_jaxpr (trans_jaxpr , res , (), * cts )
469
+ outs = core .eval_jaxpr (trans_jaxpr , res , * cts )
475
470
return [None ] * num_res + outs
476
471
primitive_transposes [fun_lin_p ] = fun_lin_transpose
477
472
@@ -544,8 +539,8 @@ def traceable(num_primals, in_tree_def, *primals_and_tangents):
544
539
yield out_flat , tree_def
545
540
546
541
547
- def call_transpose (primitive , params , jaxpr , consts , freevar_vals , args , ct ):
548
- all_args , in_tree_def = tree_flatten ((consts , freevar_vals , args , ct ))
542
+ def call_transpose (primitive , params , jaxpr , consts , args , ct ):
543
+ all_args , in_tree_def = tree_flatten ((consts , args , ct ))
549
544
fun = lu .hashable_partial (lu .wrap_init (backward_pass ), jaxpr )
550
545
fun , out_tree = flatten_fun_nokwargs (fun , in_tree_def )
551
546
params = dict (params , name = wrap_name (params ['name' ], 'transpose' ))
@@ -554,15 +549,22 @@ def call_transpose(primitive, params, jaxpr, consts, freevar_vals, args, ct):
554
549
primitive_transposes [core .call_p ] = partial (call_transpose , call_p )
555
550
primitive_transposes [pe .remat_call_p ] = partial (call_transpose , pe .remat_call_p )
556
551
557
def map_transpose(primitive, params, jaxpr, consts, args, ct):
  """Transpose rule for a mapped call primitive.

  Binds `primitive` on `backward_pass` (partially applied to `jaxpr`) using the
  flattened `(consts, args, ct)` as arguments, then reduces the cotangents of
  the arguments that were not mapped.

  Args:
    primitive: the call primitive to bind the transposed computation with.
    params: the primitive's parameters; must contain 'name' and
      'mapped_invars' (one truthy/falsy flag per argument cotangent).
    jaxpr: the jaxpr being transposed.
    consts: values for the jaxpr's constvars.
    args: the primal arguments passed through to `backward_pass`.
    ct: the incoming cotangents.

  Returns:
    A list with one cotangent per entry of `mapped_invars`. Cotangents of
    mapped arguments (and symbolic `zero`s) are passed through unchanged;
    every other cotangent is replaced by `arg_ct.sum(0)`.
  """
  all_args, in_tree_def = tree_flatten((consts, args, ct))
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Only the name changes; every other param (including 'mapped_invars') is
  # carried over to the transposed call.
  params = dict(params, name=wrap_name(params['name'], 'transpose'))
  out_flat = primitive.bind(fun, *all_args, **params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  mapped_invars = params['mapped_invars']  # True for each mapped invar
  # Unmapped invars are fanned out to every mapped instance. During transpose
  # the dual of fan-out is fan-in-sum, so their cotangents are summed.
  assert len(mapped_invars) == len(arg_cts)
  # Materialize a list rather than returning a lazy generator: the result can
  # then be iterated more than once and measured with len(), and the sums are
  # computed here instead of whenever the caller happens to iterate.
  return [arg_ct if arg_mapped or arg_ct is zero else arg_ct.sum(0)
          for arg_ct, arg_mapped in zip(arg_cts, mapped_invars)]
566
568
567
569
568
570
def jvp_jaxpr (jaxpr , nonzeros , instantiate ):
@@ -588,10 +590,10 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents):
588
590
nonzero_tangents_out = [t for t in tangents_out if t is not zero ]
589
591
yield list (primals_out ) + nonzero_tangents_out , out_nonzeros
590
592
591
- def rearrange_binders (jaxpr , primals_in , tangents_in , primals_out , tangents_out ):
593
+ def rearrange_binders (jaxpr : core . TypedJaxpr , primals_in , tangents_in , primals_out , tangents_out ):
592
594
new_invars = _perm (primals_in , tangents_in , jaxpr .jaxpr .invars )
593
595
new_outvars = _perm (primals_out , tangents_out , jaxpr .jaxpr .outvars )
594
- new_jaxpr = core .Jaxpr (jaxpr .jaxpr .constvars , jaxpr . jaxpr . freevars ,
596
+ new_jaxpr = core .Jaxpr (jaxpr .jaxpr .constvars ,
595
597
new_invars , new_outvars , jaxpr .jaxpr .eqns )
596
598
new_in_avals = _perm (primals_in , tangents_in , jaxpr .in_avals )
597
599
new_out_avals = _perm (primals_out , tangents_out , jaxpr .out_avals )
0 commit comments