Skip to content

Commit 99e0a22

Browse files
authored
Allow casting a tensor of size 1 to scalar (#850)
1 parent f750480 commit 99e0a22

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

mleap-core/src/main/scala/ml/combust/mleap/core/types/Casting.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ object Casting {
143143
}
144144
}
145145
}
146-
case (tt: TensorType, _: ScalarType) if tt.dimensions.exists(_.isEmpty) =>
146+
case (tt: TensorType, _: ScalarType) if tt.dimensions.exists(dimensions => dimensions.isEmpty || dimensions.product == 1) =>
147147
baseCast(from.base, to.base).map {
148148
_.map {
149149
c => (v: Any) => c(v.asInstanceOf[Tensor[_]](0))

mleap-core/src/test/scala/ml/combust/mleap/core/types/CastingSpec.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class CastingSpec extends FunSpec {
197197
val tc = Casting.cast(ScalarType(from), TensorType(to, Some(Seq()))).getOrElse(Success((v: Any) => v)).get
198198
val ct = Casting.cast(TensorType(from, Some(Seq())), ScalarType(to)).getOrElse(Success((v: Any) => v)).get
199199
val lct = Casting.cast(TensorType(from, Some(Seq(expectedList.length))), ListType(to)).getOrElse(Success((v: Any) => v)).get
200+
val ct1 = Casting.cast(TensorType(from, Some(Seq(1))), ScalarType(to)).getOrElse(Success((v: Any) => v)).get
200201

201202
assert(c(fromTensor) == expectedTensor)
202203
assertThrows[NullPointerException](oc(null))
@@ -205,6 +206,7 @@ class CastingSpec extends FunSpec {
205206
assert(tc(fromValue) == expectedScalarTensor)
206207
assert(ct(fromScalarTensor) == expectedValue)
207208
assert(lct(fromListTensor) == expectedList)
209+
assert(ct1(fromTensor) == expectedValue)
208210
}
209211

210212
}

0 commit comments

Comments
 (0)