Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c19c3e9
Allow shapecheck of indexing, slicing, broadcast_to, reshape, random.…
juliuskunze Jan 16, 2020
7e424c9
Fix dynamic slicing
juliuskunze Jan 17, 2020
7595316
Fix issue with float64.__index__()
juliuskunze Jan 17, 2020
eaa847d
Fix np.arange with float size, _try_canonicalize_shape
juliuskunze Jan 17, 2020
4abe17d
Cleanup: Make methods to create Poly internal (only use in Poly / sha…
juliuskunze Jan 17, 2020
79c1a7b
Fix testReshapeWithUnusualShapes (error message)
juliuskunze Jan 17, 2020
a45b879
Fix syntax for python 3.6
juliuskunze Jan 17, 2020
5d630e9
Remove Poly.__index__
juliuskunze Jan 20, 2020
212f241
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 20, 2020
99455b5
Fix tests
juliuskunze Jan 20, 2020
293caf6
Split up masking.py
juliuskunze Jan 20, 2020
3283bba
Cleanup masking
juliuskunze Jan 20, 2020
97c2506
Cleanup
juliuskunze Jan 20, 2020
4433fb3
Use abstract_eval for shapecheck, remove ShapeCheckTrace(r)
juliuskunze Jan 20, 2020
1d28237
Remove shape_rules, fix test
juliuskunze Jan 20, 2020
189e586
Remove shapes.py, move code to abstract_arrays.py / api.py
juliuskunze Jan 20, 2020
8a6f1fc
Remove safe_map/zip, is_instance from abstract_arrays, test + fix Pol…
juliuskunze Jan 20, 2020
72d8021
Add missing shapecheck_test.py
juliuskunze Jan 20, 2020
f6ba303
Cleanup, minimize changes
juliuskunze Jan 20, 2020
115581e
Minimize import diff
juliuskunze Jan 20, 2020
4e9e237
Minor
juliuskunze Jan 20, 2020
847d437
Allow shapecheck of np.where
juliuskunze Jan 21, 2020
cfb19fa
Fix np.where
juliuskunze Jan 21, 2020
732315b
Simplify gather to allow retightening type assertion in ConcreteArray
juliuskunze Jan 21, 2020
cb4696a
Remove unused imports
juliuskunze Jan 21, 2020
a760fa1
Make import style consistent
juliuskunze Jan 21, 2020
987d65c
Remove is_polymorphic, special cases in sampling, split, where.
juliuskunze Jan 21, 2020
c3c4588
Move back Poly, _parse_shape_spec into masking.py to simplify diff
juliuskunze Jan 21, 2020
bc4cdfa
Move back ShapeTest into masking_test.py to simplify diff
juliuskunze Jan 21, 2020
01404ed
Minor reverts to further simplify diff
juliuskunze Jan 21, 2020
026f933
Fix tests
juliuskunze Jan 21, 2020
5bb0730
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Jan 29, 2020
4ff7c3c
Merge remote-tracking branch 'main/master' into shapecheck-pcnn
juliuskunze Jan 30, 2020
aca2724
Minimize diff
juliuskunze Jan 30, 2020
d4a8bbb
Restore copyright, cleanup imports in masking.py
juliuskunze Jan 31, 2020
1d12235
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 6, 2020
dd4b8d2
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
02f1589
Merge branch 'master' of https://github.com/google/jax into shapechec…
juliuskunze Feb 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import collections
import functools
import itertools as it
import operator as op
import os
import threading
from warnings import warn
Expand All @@ -52,8 +51,6 @@
from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
WrapHashably, Hashable, prod, split_list)
from .lib import xla_bridge as xb
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters import partial_eval as pe
from .interpreters import xla
Expand All @@ -62,7 +59,7 @@
from .interpreters import batching
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import shapecheck, ensure_poly
from .interpreters.masking import Poly
from .config import flags, config

map = safe_map
Expand Down Expand Up @@ -1055,7 +1052,7 @@ def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if ensure_poly(poly).is_constant:
if type(poly) is not Poly or poly.is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
Expand Down
70 changes: 35 additions & 35 deletions jax/interpreters/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,17 @@ def extend_shape_envs(logical_env, padded_env):
yield
shape_envs = prev

# TODO remove remaining usages:
def is_polymorphic(shape):
return any(map(lambda d: isinstance(d, Poly), shape))
return any(map(lambda d: type(d) is Poly, shape))

def shape_as_value(expr):
if type(expr) is tuple and is_polymorphic(expr):
return tuple(eval_dim_expr(shape_envs.logical, d) if type(d) is Poly else d
for d in expr)
else:
return expr
return tuple(eval_dim_expr(shape_envs.logical, d) if type(d) is Poly else d
for d in expr)

def padded_shape_as_value(expr):
if type(expr) is tuple and is_polymorphic(expr):
return tuple(eval_dim_expr(shape_envs.padded, d) if type(d) is Poly else d
for d in expr)
else:
return expr
return tuple(eval_dim_expr(shape_envs.padded, d) if type(d) is Poly else d
for d in expr)


def mask_fun(fun, logical_env, padded_env, in_vals, shape_exprs):
Expand All @@ -89,11 +84,11 @@ def mask_subtrace(master, in_vals, shape_exprs):
yield out_vals, out_shapes


def ensure_poly(p):
if isinstance(p, Poly):
def _ensure_poly(p):
if type(p) is Poly:
return p

return constant_poly(int(p))
return _constant_poly(p)

class Poly(Counter):
"""Polynomial with integer coefficients,
Expand All @@ -111,7 +106,7 @@ def __init__(self, coeffs):
def __add__(self, other):
coeffs = self.copy()

for mon, coeff in ensure_poly(other).items():
for mon, coeff in _ensure_poly(other).items():
coeffs[mon] = coeffs.get(mon, 0) + coeff

return Poly(coeffs)
Expand All @@ -125,7 +120,7 @@ def __neg__(self):
def __mul__(self, other):
coeffs = dict()
for (mon1, coeff1), (mon2, coeff2) \
in it.product(self.items(), ensure_poly(other).items()):
in it.product(self.items(), _ensure_poly(other).items()):
mon = Mon(mon1 + mon2) # add monomials' id degrees
coeff = coeff1 * coeff2 # multiply integer coeffs
coeffs[mon] = coeffs.get(mon, 0) + coeff # accumulate coeffs
Expand All @@ -151,9 +146,7 @@ def __mod__(self, divisor):

def __divmod__(self, divisor):
if self.is_constant:
q, r = divmod(int(self), divisor)

return constant_poly(q), r
return divmod(int(self), divisor)

def divided(count):
q, r = divmod(count, divisor)
Expand All @@ -170,13 +163,13 @@ def __hash__(self):
return hash(super())

def __eq__(self, other):
return super().__eq__(ensure_poly(other))
return super().__eq__(_ensure_poly(other))

def __ne__(self, other):
return not self == other

def __ge__(self, other):
other = ensure_poly(other)
other = _ensure_poly(other)

if other.is_constant and self.is_constant:
return int(self) >= int(other)
Expand All @@ -195,13 +188,13 @@ def __ge__(self, other):
.format(self, other))

def __le__(self, other):
return ensure_poly(other) >= self
return _ensure_poly(other) >= self

def __lt__(self, other):
return not (self >= other)

def __gt__(self, other):
return not (ensure_poly(other) >= self)
return not (_ensure_poly(other) >= self)

def __str__(self):
return ' + '.join('{} {}'.format(v, k) if (v != 1 or k.degree == 0) else str(k)
Expand All @@ -212,6 +205,9 @@ def __int__(self):

return int(next(iter(self.values())))

def __index__(self):
return self

@property
def is_constant(self):
return len(self) == 1 and next(iter(self)).degree == 0
Expand Down Expand Up @@ -285,7 +281,7 @@ def __str__(self):
return 'ShapeSpec({})'.format(', '.join(map(str, self)))

def finalize_spec(spec, shape):
return tuple(parse_lit(d) if e is monomorphic_dim else e
return tuple(_parse_lit(d) if e is monomorphic_dim else e
for e, d in zip(spec, shape))

def parse_spec(spec=''):
Expand All @@ -294,30 +290,30 @@ def parse_spec(spec=''):
if spec[0] == '(':
if spec[-1] != ')': raise ShapeSyntaxError(spec)
spec = spec[1:-1]
dims = map(parse_dim, spec.replace(' ', '').strip(',').split(','))
dims = map(_parse_dim, spec.replace(' ', '').strip(',').split(','))
return ShapeSpec(dims)

def parse_dim(spec):
def _parse_dim(spec):
if '+' in spec:
terms = map(parse_dim, spec.split('+'))
terms = map(_parse_dim, spec.split('+'))
return functools.reduce(op.add, terms)
elif '*' in spec:
terms = map(parse_dim, spec.split('*'))
terms = map(_parse_dim, spec.split('*'))
return functools.reduce(op.mul, terms)
elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
return parse_lit(spec)
return _parse_lit(spec)
elif spec in identifiers:
return parse_id(spec)
return _parse_id(spec)
elif spec == '_':
return monomorphic_dim
else:
raise ShapeSyntaxError(spec)
digits = frozenset(string.digits)
identifiers = frozenset(string.ascii_lowercase)

def parse_id(name): return Poly({Mon({name: 1}): 1})
def parse_lit(val_str): return constant_poly(int(val_str))
def constant_poly(val): return Poly({Mon(): val})
def _parse_id(name): return Poly({Mon({name: 1}): 1})
def _parse_lit(val_str): return _constant_poly(int(val_str))
def _constant_poly(val): return Poly({Mon(): val.__index__()})

class MonomorphicDim(object):
def __str__(self): return '_'
Expand Down Expand Up @@ -352,7 +348,7 @@ def aval(self):
return ShapedArray(self.shape_expr, self.val.dtype)

def is_pure(self):
return all(ensure_poly(poly).is_constant for poly in self.shape_expr)
return all(_ensure_poly(poly).is_constant for poly in self.shape_expr)

def full_lower(self):
if self.is_pure():
Expand Down Expand Up @@ -456,7 +452,11 @@ def process_primitive(self, primitive, tracers, params):
if shape_rule is None:
raise NotImplementedError('Shape rule for {} not implemented yet.'.format(primitive))
out_shape = shape_rule(*avals, **params)
return ShapeCheckTracer(self, out_shape)

if primitive.multiple_results:
return map(partial(ShapeCheckTracer, self), out_shape)
else:
return ShapeCheckTracer(self, out_shape)

def process_call(self, call_primitive, f, tracers, params):
# TODO apply proper subtrace:
Expand Down
23 changes: 12 additions & 11 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def broadcast_shapes(*shapes):
.format(tuple(map(tuple, shapes))))
return tuple(result_shape)

def _try_canonicalize_shape(shape):
try:
return tuple(map(lambda x: x.__index__(), shape))
except (TypeError, AttributeError):
return None

def _canonicalize_shape(shape):
"""Canonicalizes and checks for errors in a user-provided shape value.

Expand All @@ -83,13 +89,10 @@ def _canonicalize_shape(shape):
Returns:
A tuple of integers.
"""
# TODO(mattjj): this next check is a temporary workaround for masking
if (type(shape) is tuple and masking.is_polymorphic(shape)):
return shape
try:
return tuple(map(operator.index, shape))
except TypeError:
pass
canonical = _try_canonicalize_shape(shape)
if canonical is not None:
return canonical

msg = ("Shapes must be 1D sequences of concrete values of integer type, "
"got {}.")
if any(isinstance(x, core.Tracer) and isinstance(core.get_aval(x), ShapedArray)
Expand Down Expand Up @@ -1082,7 +1085,7 @@ def iota(dtype, size):
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
size = int(size)
size = size.__index__()
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, size)
aval = ShapedArray((size,), dtype)
Expand Down Expand Up @@ -4354,8 +4357,6 @@ def conv_transpose_shape_tuple(lhs_shape, rhs_shape, window_strides, padding,

def _check_shapelike(fun_name, arg_name, obj):
"""Check that `obj` is a shape-like value (e.g. tuple of nonnegative ints)."""
if (type(obj) is tuple and masking.is_polymorphic(obj)):
return obj
if not isinstance(obj, (tuple, list, onp.ndarray)):
msg = "{} {} must be of type tuple/list/ndarray, got {}."
raise TypeError(msg.format(fun_name, arg_name, type(obj)))
Expand All @@ -4366,7 +4367,7 @@ def _check_shapelike(fun_name, arg_name, obj):
if obj_arr.ndim != 1:
msg = "{} {} must be rank 1, got {}."
raise TypeError(msg.format(obj_arr.ndim))
if not dtypes.issubdtype(obj_arr.dtype, onp.integer):
if _try_canonicalize_shape(obj_arr) is None:
msg = "{} {} must have every element be an integer type, got {}."
raise TypeError(msg.format(fun_name, arg_name, tuple(map(type, obj))))
if not (obj_arr >= 0).all():
Expand Down
Loading