Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c19c3e9
Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.…
juliuskunze Jan 16, 2020
7e424c9
Fix dynamic slicing
juliuskunze Jan 17, 2020
7595316
Fix issue with float64.__index__()
juliuskunze Jan 17, 2020
eaa847d
Fix np.arange with float size, _try_canonicalize_shape
juliuskunze Jan 17, 2020
4abe17d
Cleanup: Make methods to create Poly internal (only use in Poly / sha…
juliuskunze Jan 17, 2020
79c1a7b
Fix testReshapeWithUnusualShapes (error message)
juliuskunze Jan 17, 2020
a45b879
Fix syntax for python 3.6
juliuskunze Jan 17, 2020
5d630e9
Remove Poly.__index__
juliuskunze Jan 20, 2020
212f241
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 20, 2020
99455b5
Fix tests
juliuskunze Jan 20, 2020
293caf6
Split up masking.py
juliuskunze Jan 20, 2020
3283bba
Cleanup masking
juliuskunze Jan 20, 2020
97c2506
Cleanup
juliuskunze Jan 20, 2020
4433fb3
Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)
juliuskunze Jan 20, 2020
1d28237
Remove shape_rules, fix test
juliuskunze Jan 20, 2020
189e586
Remove shapes.py, move code to abstract_arrays.py / api.py
juliuskunze Jan 20, 2020
8a6f1fc
Remove safe_map/zip, is_instance from abstract_arrays, test + fix Pol…
juliuskunze Jan 20, 2020
72d8021
Add missing shapecheck_test.py
juliuskunze Jan 20, 2020
f6ba303
Cleanup, minimize changes
juliuskunze Jan 20, 2020
115581e
Minimize import diff
juliuskunze Jan 20, 2020
4e9e237
Minor
juliuskunze Jan 20, 2020
847d437
Allow shapecheck of np.where
juliuskunze Jan 21, 2020
cfb19fa
Fix np.where
juliuskunze Jan 21, 2020
732315b
Simplify gather to allow retightening type assertion in ConcreteArray
juliuskunze Jan 21, 2020
cb4696a
Remove unused imports
juliuskunze Jan 21, 2020
a760fa1
Make import style consistent
juliuskunze Jan 21, 2020
987d65c
Remove is_polymorphic, special cases in sampling, split, where.
juliuskunze Jan 21, 2020
c3c4588
Move back Poly, _parse_shape_spec into masking.py to simplify diff
juliuskunze Jan 21, 2020
bc4cdfa
Move back ShapeTest into masking_test.py to simplify diff
juliuskunze Jan 21, 2020
01404ed
Minor reverts to further simplify diff
juliuskunze Jan 21, 2020
026f933
Fix tests
juliuskunze Jan 21, 2020
5bb0730
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 29, 2020
4ff7c3c
Merge remote-tracking branch 'main/master' into shapecheck-pcnn
juliuskunze Jan 30, 2020
aca2724
Minimize diff
juliuskunze Jan 30, 2020
d4a8bbb
Restore copyright, cleanup imports in masking.py
juliuskunze Jan 31, 2020
1d12235
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 6, 2020
dd4b8d2
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
02f1589
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import collections
import functools
import itertools as it
import operator as op
import os
import string
import threading
from warnings import warn

Expand All @@ -51,14 +51,14 @@
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import shapecheck, ensure_poly
from .config import flags, config, bool_env

map = safe_map
Expand Down Expand Up @@ -1038,24 +1038,23 @@ def wrapped_fun(args, logical_env):
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
if not out_shapes == list(out_shapes_):
raise masking.ShapeError
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
for out, expr in zip(outs, out_shapes)):
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
for out, shape in zip(outs, out_shapes)):
raise masking.ShapeError
return tree_unflatten(out_tree(), outs)
return wrapped_fun

def _remap_ids(names, shape_spec):
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
mdim = masking.monomorphic_dim
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not mdim else mdim for poly in shape_spec)
if poly is not masking._monomorphic_dim else
masking._monomorphic_dim for poly in shape_spec)

def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if ensure_poly(poly).is_constant:
if type(poly) is not Poly or poly.is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
Expand All @@ -1070,16 +1069,13 @@ def shapecheck(in_shapes, out_shape, fun):
out_shapes, out_tree = tree_flatten(out_shape)
out_shapes = map(masking.parse_spec, out_shapes)
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
raise masking.ShapeError
return fun

def _shape_spec_consistent(spec, expr):
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)


def jvp(fun, primals, tangents):
"""Computes a (forward-mode) Jacobian-vector product of `fun`.

Expand Down
Loading