Skip to content

Commit 0933538

Browse files
juliuskunzesrvasude
authored andcommitted
Allow shapecheck of PixelCNN++ (jax-ml#2017)
* Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.uniform, iota, simple cases of split * Fix dynamic slicing * Fix issue with float64.__index__() * Fix np.arange with float size, _try_canonicalize_shape * Cleanup: Make methods to create Poly internal (only use in Poly / shape spec parsing) * Fix testReshapeWithUnusualShapes (error message) * Fix syntax for python 3.6 * Remove Poly.__index__ * Fix tests * Split up masking.py * Cleanup masking * Cleanup * Use abstract_eval for shapecheck, remove ShapeCheckTrace(r) * Remove shape_rules, fix test * Remove shapes.py, move code to abstract_arrays.py / api.py * Remove safe_map/zip, is_instance from abstract_arrays, test + fix Poly hash, minimize import diff * Add missing shapecheck_test.py * Cleanup, minimize changes * Minimize import diff * Minor * Allow shapecheck of np.where * Fix np.where * Simplify gather to allow retightening type assertion in ConcreteArray * Remove unused imports * Make import style consistent * Remove is_polymorphic, special cases in sampling, split, where. * Move back Poly, _parse_shape_spec into masking.py to simplify diff * Move back ShapeTest into masking_test.py to simplify diff * Minor reverts to further simplify diff * Fix tests * Minimize diff * Restore copyright, cleanup imports in masking.py * Merge branch 'master' of https://github.com/google/jax into shapecheck-pcnn # Conflicts: # jax/api.py # jax/numpy/lax_numpy.py
1 parent b727db5 commit 0933538

File tree

8 files changed

+224
-184
lines changed

8 files changed

+224
-184
lines changed

jax/api.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import collections
2828
import functools
2929
import itertools as it
30-
import operator as op
3130
import os
31+
import string
3232
import threading
3333
from warnings import warn
3434

@@ -51,14 +51,14 @@
5151
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
5252
host_id, host_ids, host_count)
5353
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
54+
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
5455
from .interpreters import partial_eval as pe
5556
from .interpreters import xla
5657
from .interpreters import pxla
5758
from .interpreters import ad
5859
from .interpreters import batching
5960
from .interpreters import parallel
6061
from .interpreters import masking
61-
from .interpreters.masking import shapecheck, ensure_poly
6262
from .config import flags, config, bool_env
6363

6464
map = safe_map
@@ -1038,24 +1038,23 @@ def wrapped_fun(args, logical_env):
10381038
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
10391039
if not out_shapes == list(out_shapes_):
10401040
raise masking.ShapeError
1041-
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
1042-
for out, expr in zip(outs, out_shapes)):
1041+
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
1042+
for out, shape in zip(outs, out_shapes)):
10431043
raise masking.ShapeError
10441044
return tree_unflatten(out_tree(), outs)
10451045
return wrapped_fun
10461046

10471047
def _remap_ids(names, shape_spec):
1048-
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
1049-
mdim = masking.monomorphic_dim
1050-
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
1048+
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
10511049
: coeff for mon, coeff in poly.items()})
1052-
if poly is not mdim else mdim for poly in shape_spec)
1050+
if poly is not masking._monomorphic_dim else
1051+
masking._monomorphic_dim for poly in shape_spec)
10531052

10541053
def _bind_shapes(shape_exprs, shapes):
10551054
env = {}
10561055
for shape_expr, shape in zip(shape_exprs, shapes):
10571056
for poly, d in zip(shape_expr, shape):
1058-
if ensure_poly(poly).is_constant:
1057+
if type(poly) is not Poly or poly.is_constant:
10591058
continue
10601059
else:
10611060
(binder,), = poly # TODO generalize to handle striding
@@ -1070,16 +1069,13 @@ def shapecheck(in_shapes, out_shape, fun):
10701069
out_shapes, out_tree = tree_flatten(out_shape)
10711070
out_shapes = map(masking.parse_spec, out_shapes)
10721071
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
1073-
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
1072+
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
1073+
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
10741074
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
1075-
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
1075+
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
10761076
raise masking.ShapeError
10771077
return fun
10781078

1079-
def _shape_spec_consistent(spec, expr):
1080-
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)
1081-
1082-
10831079
def jvp(fun, primals, tangents):
10841080
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
10851081

0 commit comments

Comments
 (0)