@@ -52,10 +52,10 @@ public static Shape scalar() {
5252 /**
5353 * Create a Shape representing a scalar or an N-dimensional value.
5454 *
55- * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
56- * with the provided size for each dimension. A -1 indicates that the size of the corresponding
57- * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
58- * For example:
55+ * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
56+ * the provided size for each dimension. A -1 indicates that the size of the corresponding
57+ * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
58+ * example:
5959 *
6060 * <pre>{@code
6161 * // A 2-element vector.
@@ -88,11 +88,11 @@ public static Shape of(long... dimensionSizes) {
8888 /**
8989 * Returns the total number of elements a Tensor with this Shape would have.
9090 *
91- * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
92- * {@link Shape#UNKNOWN_SIZE} is returned.
91+ * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
92+ * Shape#UNKNOWN_SIZE} is returned.
9393 *
9494 * @return The total number of elements a Tensor with this shape would have if it can be
95- * calculated, else {@link Shape#UNKNOWN_SIZE}.
95+ * calculated, else {@link Shape#UNKNOWN_SIZE}.
9696 */
9797 public long size () {
9898 if (size == null ) {
@@ -108,12 +108,11 @@ public long size() {
108108 * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109109 *
110110 * @param i the index of the dimension to get the size for. If this Shape has a known number of
111- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative,
112- * in which case the position is counted from the end of the shape. E.g.:
113- * {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
114- * the second to last dimension etc.
111+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
112+ * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113+ * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
115114 * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
116- * otherwise.
115+ * otherwise.
117116 */
118117 public long size (int i ) {
119118 if (dimensionSizes == null ) {
@@ -167,8 +166,8 @@ public boolean isUnknown() {
167166 }
168167
169168 /**
170- * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
171- * change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
169+ * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
170+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
172171 */
173172 public long [] asArray () {
174173 if (this .dimensionSizes == null ) {
@@ -186,15 +185,16 @@ public int hashCode() {
186185 /**
187186 * Equals implementation for Shapes. Two Shapes are considered equal iff:
188187 *
188+ * <p>
189189 * <ul>
190- * <li>the number of dimensions is defined and equal for both
191- * <li>the size of each dimension is defined and equal for both
190+ * <li>the number of dimensions is defined and equal for both
191+ * <li>the size of each dimension is defined and equal for both
192192 * </ul>
193193 *
194194 * <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
195- * shape has an unknown number of dimensions (even if both return {@code true} for
196- * {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
197- * equal itself, even if it is unknown or contains unknown dimensions.
195+ * shape has an unknown number of dimensions (even if both return {@code true} for {@link
196+ * Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
197+ * even if it is unknown or contains unknown dimensions.
198198 */
199199 @ Override
200200 public boolean equals (Object obj ) {
@@ -233,17 +233,17 @@ public Shape head() {
233233 }
234234
235235 /**
236- * Returns an n-dimensional Shape with the dimensions matching the first n dimensions
237- * of this shape
236+ * Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237+ * shape
238238 *
239- * @param n the number of leading dimensions to get, must be < = than {@link Shape#numDimensions()}
240- * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
241- * of this Shape
239+ * @param n the number of leading dimensions to get, must be < = than {@link Shape#numDimensions()}
240+ * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241+ * this Shape
242242 */
243243 public Shape take (int n ) {
244244 if (n > numDimensions ()) {
245- throw new ArrayIndexOutOfBoundsException ("Cannot take " + n +
246- " dimensions, shape has only " + numDimensions () + "." );
245+ throw new ArrayIndexOutOfBoundsException (
246+ "Cannot take " + n + " dimensions, shape has only " + numDimensions () + "." );
247247 }
248248 long [] newDimensions = new long [n ];
249249 System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , n );
@@ -257,18 +257,18 @@ public Shape tail() {
257257 }
258258
259259 /**
260- * Returns an n-dimensional Shape with the dimensions matching the last n dimensions
261- * of this Shape.
260+ * Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
261+ * Shape.
262262 *
263- * @param n the number of trailing dimensions to get, must be < = than
264- * {@link Shape#numDimensions()}
263+ * @param n the number of trailing dimensions to get, must be < = than {@link
264+ * Shape#numDimensions()}
265265 * @return an n-dimensional shape with the dimensions matching the last n dimensions of this
266- * Shape, never null
266+ * Shape, never null
267267 */
268268 public Shape takeLast (int n ) {
269269 if (n > numDimensions ()) {
270- throw new ArrayIndexOutOfBoundsException ("Cannot take last " + n +
271- " dimensions, shape has only " + numDimensions () + "." );
270+ throw new ArrayIndexOutOfBoundsException (
271+ "Cannot take last " + n + " dimensions, shape has only " + numDimensions () + "." );
272272 }
273273 long [] newDimensions = new long [n ];
274274 System .arraycopy (dimensionSizes , numDimensions () - n , newDimensions , 0 , n );
@@ -280,8 +280,8 @@ public Shape takeLast(int n) {
280280 * {@link Shape#isUnknown()} must be {@code false}.
281281 *
282282 * @param firstDimension the dimension to prepend
283- * @return a new shape with the given dimension first, followed by this Shape's dimensions,
284- * never null
283+ * @return a new shape with the given dimension first, followed by this Shape's dimensions, never
284+ * null
285285 */
286286 public Shape prepend (long firstDimension ) {
287287 long [] newDimensions = new long [dimensionSizes .length + 1 ];
@@ -292,8 +292,8 @@ public Shape prepend(long firstDimension) {
292292 }
293293
294294 /**
295- * Returns a new Shape, with a new last dimension added. In order for this call to succeed,
296- * {@link Shape#isUnknown()} must be {@code false}.
295+ * Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
296+ * Shape#isUnknown()} must be {@code false}.
297297 *
298298 * @param lastDimension the dimension to append
299299 * @return a new Shape with this Shape's dimensions followed by the given dimension, never null
@@ -307,38 +307,36 @@ public Shape append(long lastDimension) {
307307 }
308308
309309 /**
310- * Returns a new Shape, with another Shape's dimensions prepended.
311- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
312- * E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
310+ * Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
311+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
312+ * Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
313313 *
314314 * @param other another Shape, must not be {@code null}, must not be unknown
315- * @return A new Shape consisting of the given Shapes 's dimensions followed by this Shape's
316- * dimensions, never null
315+ * @return A new Shape consisting of the given Shape 's dimensions followed by this Shape's
316+ * dimensions, never null
317317 */
318318 public Shape prepend (Shape other ) {
319319 long [] newDimensions = new long [other .dimensionSizes .length + dimensionSizes .length ];
320- System .arraycopy (other .dimensionSizes , 0 ,
321- newDimensions , 0 , other .dimensionSizes .length );
322- System .arraycopy (dimensionSizes , 0 ,
323- newDimensions , other .dimensionSizes .length , dimensionSizes .length );
320+ System .arraycopy (other .dimensionSizes , 0 , newDimensions , 0 , other .dimensionSizes .length );
321+ System .arraycopy (
322+ dimensionSizes , 0 , newDimensions , other .dimensionSizes .length , dimensionSizes .length );
324323 return Shape .of (newDimensions );
325324 }
326325
327326 /**
328- * Returns a new Shape, with another Shapes' dimensions appended.
329- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
330- * e.g. {@code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
327+ * Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
328+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
329+ * Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
331330 *
332331 * @param other another Shape, must not be {@code null}, must not be unknown
333- * @return A new Shape consisting of this Shapes 's dimensions followed by the given Shape's
334- * dimensions
332+ * @return A new Shape consisting of this Shape 's dimensions followed by the given Shape's
333+ * dimensions
335334 */
336335 public Shape append (Shape other ) {
337336 long [] newDimensions = new long [dimensionSizes .length + other .dimensionSizes .length ];
338- System .arraycopy (dimensionSizes , 0 ,
339- newDimensions , 0 , dimensionSizes .length );
340- System .arraycopy (other .dimensionSizes , 0 ,
341- newDimensions , dimensionSizes .length , other .dimensionSizes .length );
337+ System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , dimensionSizes .length );
338+ System .arraycopy (
339+ other .dimensionSizes , 0 , newDimensions , dimensionSizes .length , other .dimensionSizes .length );
342340 return Shape .of (newDimensions );
343341 }
344342
@@ -355,4 +353,74 @@ private static long computeSize(long[] dimensionSizes) {
355353 }
356354 return computedSize ;
357355 }
356+
357+ /**
358+ * Determines whether another shape is compatible with this one.
359+ *
360+ * <p>
361+ *
362+ * <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
363+ * that both shapes can represent. Thus, compatibility allows the shape inference code to reason
364+ * about partially-defined shapes. For example:
365+ *
366+ * <ul>
367+ * <li><code>Shape.unknown()</code> is compatible with all shapes.
368+ * <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
369+ * shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
370+ * not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
371+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
372+ * <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
373+ * size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
374+ * <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
375+ * </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
376+ * <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
377+ * Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
378+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
379+ * compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
380+ * </code>.
381+ * </ul>
382+ *
383+ * <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
384+ * <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
385+ * Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
386+ * </code> is not compatible with <code>Shape(4, 4)</code>.
387+ *
388+ * <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
389+ * of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
390+ * at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
391+ *
392+ * <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
393+ * one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
394+ * is "stretched" with dimensions of 1.
395+ *
396+ * @param shape The other shape
397+ * @return true, if the two shapes are compatible.
398+ */
399+ public boolean isCompatibleWith (Shape shape ) {
400+ if (!this .isUnknown () && !shape .isUnknown ()) {
401+ if (numDimensions () != shape .numDimensions ()) {
402+ return false ;
403+ }
404+ for (int i = 0 ; i < numDimensions (); i ++) {
405+ if (!isCompatible (size (i ), shape .size (i ))) {
406+ return false ;
407+ }
408+ }
409+ }
410+ return true ;
411+ }
412+
413+ /**
414+ * Test to see if two shape dimensions are compatible.
415+ *
416+ * <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
417+ * dimensions are equal
418+ *
419+ * @param dim the first dimension
420+ * @param otherDim the second dimension
421+ * @return true, if both dimensions are compatible
422+ */
423+ public static boolean isCompatible (long dim , long otherDim ) {
424+ return dim == Shape .UNKNOWN_SIZE || otherDim == Shape .UNKNOWN_SIZE || dim == otherDim ;
425+ }
358426}
0 commit comments