Skip to content

Commit d6292b8

Browse files
committed
instantiate zeros
fix dtype
1 parent 49a8901 commit d6292b8

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

jax/experimental/jet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from jax.util import unzip2
2323
from jax import ad_util
2424
from jax.tree_util import (register_pytree_node, tree_structure,
25-
treedef_is_leaf, tree_flatten, tree_unflatten)
25+
treedef_is_leaf, tree_flatten, tree_unflatten, tree_map)
2626
import jax.linear_util as lu
2727
from jax.interpreters import xla
2828
from jax.lax import lax
@@ -59,6 +59,8 @@ def jet_fun(primals, series):
5959
with core.new_master(JetTrace) as master:
6060
out_primals, out_terms = yield (master, primals, series), {}
6161
del master
62+
out_terms = [tree_map(lambda x: onp.zeros_like(x, dtype=onp.result_type(out_primals[0])), series[0])
63+
if s is zero_series else s for s in out_terms]
6264
yield out_primals, out_terms
6365

6466
@lu.transformation

tests/jet_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,21 @@ def test_select(self):
291291
series_in = (terms_b, terms_x, terms_y)
292292
self.check_jet(np.where, primals, series_in)
293293

294+
def test_inst_zero(self):
295+
def f(x):
296+
return 2.
297+
def g(x):
298+
return 2. + 0 * x
299+
x = np.ones(1)
300+
order = 3
301+
f_out_primals, f_out_series = jet(f, (x, ), ([np.ones_like(x) for _ in range(order)], ))
302+
assert f_out_series is not zero_series
303+
304+
g_out_primals, g_out_series = jet(g, (x, ), ([np.ones_like(x) for _ in range(order)], ))
305+
306+
assert g_out_primals == f_out_primals
307+
assert g_out_series == f_out_series
308+
294309

295310
if __name__ == '__main__':
296311
absltest.main()

0 commit comments

Comments
 (0)