27
27
import collections
28
28
import functools
29
29
import itertools as it
30
- import operator as op
31
30
import os
31
+ import string
32
32
import threading
33
33
from warnings import warn
34
34
51
51
from .lib .xla_bridge import (device_count , local_device_count , devices , local_devices ,
52
52
host_id , host_ids , host_count )
53
53
from .abstract_arrays import ConcreteArray , ShapedArray , raise_to_shaped
54
+ from .interpreters .masking import eval_polymorphic_shape , Poly , Mon
54
55
from .interpreters import partial_eval as pe
55
56
from .interpreters import xla
56
57
from .interpreters import pxla
57
58
from .interpreters import ad
58
59
from .interpreters import batching
59
60
from .interpreters import parallel
60
61
from .interpreters import masking
61
- from .interpreters .masking import shapecheck , ensure_poly
62
62
from .config import flags , config , bool_env
63
63
64
64
map = safe_map
@@ -1038,24 +1038,23 @@ def wrapped_fun(args, logical_env):
1038
1038
out_shapes = map (masking .finalize_spec , out_specs , map (onp .shape , outs ))
1039
1039
if not out_shapes == list (out_shapes_ ):
1040
1040
raise masking .ShapeError
1041
- if not all (onp .shape (out ) == masking . eval_shape_expr ( padded_env , expr )
1042
- for out , expr in zip (outs , out_shapes )):
1041
+ if not all (onp .shape (out ) == eval_polymorphic_shape ( shape , padded_env )
1042
+ for out , shape in zip (outs , out_shapes )):
1043
1043
raise masking .ShapeError
1044
1044
return tree_unflatten (out_tree (), outs )
1045
1045
return wrapped_fun
1046
1046
1047
1047
def _remap_ids (names , shape_spec ):
1048
- ShapeSpec , Poly , Mon = masking .ShapeSpec , masking .Poly , masking .Mon
1049
- mdim = masking .monomorphic_dim
1050
- return ShapeSpec (Poly ({Mon ({names [id ] : deg for id , deg in mon .items ()})
1048
+ return masking .ShapeSpec (Poly ({Mon ({names [id ] : deg for id , deg in mon .items ()})
1051
1049
: coeff for mon , coeff in poly .items ()})
1052
- if poly is not mdim else mdim for poly in shape_spec )
1050
+ if poly is not masking ._monomorphic_dim else
1051
+ masking ._monomorphic_dim for poly in shape_spec )
1053
1052
1054
1053
def _bind_shapes (shape_exprs , shapes ):
1055
1054
env = {}
1056
1055
for shape_expr , shape in zip (shape_exprs , shapes ):
1057
1056
for poly , d in zip (shape_expr , shape ):
1058
- if ensure_poly (poly ).is_constant :
1057
+ if type (poly ) is not Poly or poly .is_constant :
1059
1058
continue
1060
1059
else :
1061
1060
(binder ,), = poly # TODO generalize to handle striding
@@ -1070,16 +1069,13 @@ def shapecheck(in_shapes, out_shape, fun):
1070
1069
out_shapes , out_tree = tree_flatten (out_shape )
1071
1070
out_shapes = map (masking .parse_spec , out_shapes )
1072
1071
flat_fun , out_tree_ = flatten_fun_nokwargs (lu .wrap_init (fun ), in_tree )
1073
- out_shapes_ = masking .shapecheck (flat_fun , in_shapes )
1072
+ avals = map (partial (ShapedArray , dtype = onp .float32 ), in_shapes )
1073
+ out_shapes_ = [o .shape for o in pe .abstract_eval_fun (flat_fun .call_wrapped , * avals )]
1074
1074
if out_tree != out_tree_ (): raise TypeError ("pytree mismatch" )
1075
- if not all (map (_shape_spec_consistent , out_shapes , out_shapes_ )):
1075
+ if not all (map (masking . _shape_spec_consistent , out_shapes , out_shapes_ )):
1076
1076
raise masking .ShapeError
1077
1077
return fun
1078
1078
1079
- def _shape_spec_consistent (spec , expr ):
1080
- return all (a == b for a , b in zip (spec , expr ) if a is not masking .monomorphic_dim )
1081
-
1082
-
1083
1079
def jvp (fun , primals , tangents ):
1084
1080
"""Computes a (forward-mode) Jacobian-vector product of `fun`.
1085
1081
0 commit comments