Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.StringJoiner;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.DataType;

/**
* Tensor helper methods.
Expand Down Expand Up @@ -56,42 +56,57 @@ public static String toString(Tensor tensor, Integer maxWidth) {
}
return String.valueOf(iterator.next().getObject());
}
return toString(iterator, shape, 0, maxWidth);
return toString(iterator, tensor.dataType(), shape, 0, maxWidth);
}

/**
* @param iterator an iterator over the scalars
* @param shape the shape of the tensor
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited).
* This limit may surpassed if the first or last element are too long.
* Convert an element of a tensor to string, in a way that may depend on the data type.
*
* @param dtype the tensor's data type
* @param data the element
* @return the element's string representation
*/
private static String elementToString(DataType dtype, Object data) {
if (dtype == DataType.DT_STRING) {
return '"' + data.toString() + '"';
} else {
return data.toString();
}
}

/**
* @param iterator an iterator over the scalars
* @param shape the shape of the tensor
* @param maxWidth the maximum width of the output in characters ({@code null} if unlimited). This limit may surpassed
* if the first or last element are too long.
* @param dimension the current dimension being processed
* @return the String representation of the tensor data at {@code dimension}
*/
private static String toString(Iterator<? extends NdArray<?>> iterator, Shape shape,
private static String toString(Iterator<? extends NdArray<?>> iterator, DataType dtype, Shape shape,
int dimension, Integer maxWidth) {
if (dimension < shape.numDimensions() - 1) {
StringJoiner joiner = new StringJoiner("\n", indent(dimension) + "[\n",
"\n" + indent(dimension) + "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = toString(iterator, shape, dimension + 1, maxWidth);
String element = toString(iterator, dtype, shape, dimension + 1, maxWidth);
joiner.add(element);
}
return joiner.toString();
}
if (maxWidth == null) {
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = iterator.next().getObject().toString();
joiner.add(element);
Object element = iterator.next().getObject();
joiner.add(elementToString(dtype, element));
}
return joiner.toString();
}
List<Integer> lengths = new ArrayList<>();
StringJoiner joiner = new StringJoiner(", ", indent(dimension) + "[", "]");
int lengthBefore = "]".length();
for (long i = 0, size = shape.size(dimension); i < size; ++i) {
String element = iterator.next().getObject().toString();
joiner.add(element);
Object element = iterator.next().getObject();
joiner.add(elementToString(dtype, element));
int addedLength = joiner.length() - lengthBefore;
lengths.add(addedLength);
lengthBefore += addedLength;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,26 +544,26 @@ public void gracefullyFailCreationFromNullArrayForStringTensor() {

@Test
public void dataToString() {
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
String actual = t.dataToString();
assertEquals("[3, 0, 1]", actual);
}
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
String actual = t.dataToString(Tensor.maxWidth(5));
// Cannot remove first or last element
assertEquals("[3, 0, 1]", actual);
}
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1}))) {
try (TInt32 t = TInt32.vectorOf(3, 0, 1)) {
String actual = t.dataToString(Tensor.maxWidth(6));
// Do not insert ellipses if it increases the length
assertEquals("[3, 0, 1]", actual);
}
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2}))) {
try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) {
String actual = t.dataToString(Tensor.maxWidth(11));
// Limit may be surpassed if first or last element are too long
assertEquals("[3, ..., 2]", actual);
}
try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2}))) {
try (TInt32 t = TInt32.vectorOf(3, 0, 1, 2)) {
String actual = t.dataToString(Tensor.maxWidth(12));
assertEquals("[3, 0, 1, 2]", actual);
}
Expand All @@ -574,10 +574,27 @@ public void dataToString() {
+ " [3, 2, 1]\n"
+ "]", actual);
}
try (RawTensor t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[]{3, 0, 1, 2})).asRawTensor()) {
try (RawTensor t = TInt32.vectorOf(3, 0, 1, 2).asRawTensor()) {
String actual = t.dataToString(Tensor.maxWidth(12));
assertEquals("[3, 0, 1, 2]", actual);
}
// different data types
try (RawTensor t = TFloat32.vectorOf(3.0101f, 0, 1.5f, 2).asRawTensor()) {
String actual = t.dataToString();
assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual);
}
try (RawTensor t = TFloat64.vectorOf(3.0101, 0, 1.5, 2).asRawTensor()) {
String actual = t.dataToString();
assertEquals("[3.0101, 0.0, 1.5, 2.0]", actual);
}
try (RawTensor t = TBool.vectorOf(true, true, false, true).asRawTensor()) {
String actual = t.dataToString();
assertEquals("[true, true, false, true]", actual);
}
try (RawTensor t = TString.vectorOf("a", "b", "c").asRawTensor()) {
String actual = t.dataToString();
assertEquals("[\"a\", \"b\", \"c\"]", actual);
}
}

// Workaround for cross compiliation
Expand Down