File tree Expand file tree Collapse file tree 2 files changed +3
-1
lines changed
main/scala/ml/combust/mleap/core/types
test/scala/ml/combust/mleap/core/types Expand file tree Collapse file tree 2 files changed +3
-1
lines changed Original file line number Diff line number Diff line change @@ -143,7 +143,7 @@ object Casting {
143
143
}
144
144
}
145
145
}
146
- case (tt : TensorType , _ : ScalarType ) if tt.dimensions.exists(_ .isEmpty) =>
146
+ case (tt : TensorType , _ : ScalarType ) if tt.dimensions.exists(dimensions => dimensions .isEmpty || dimensions.product == 1 ) =>
147
147
baseCast(from.base, to.base).map {
148
148
_.map {
149
149
c => (v : Any ) => c(v.asInstanceOf [Tensor [_]](0 ))
Original file line number Diff line number Diff line change @@ -197,6 +197,7 @@ class CastingSpec extends FunSpec {
197
197
val tc = Casting .cast(ScalarType (from), TensorType (to, Some (Seq ()))).getOrElse(Success ((v : Any ) => v)).get
198
198
val ct = Casting .cast(TensorType (from, Some (Seq ())), ScalarType (to)).getOrElse(Success ((v : Any ) => v)).get
199
199
val lct = Casting .cast(TensorType (from, Some (Seq (expectedList.length))), ListType (to)).getOrElse(Success ((v : Any ) => v)).get
200
+ val ct1 = Casting .cast(TensorType (from, Some (Seq (1 ))), ScalarType (to)).getOrElse(Success ((v : Any ) => v)).get
200
201
201
202
assert(c(fromTensor) == expectedTensor)
202
203
assertThrows[NullPointerException ](oc(null ))
@@ -205,6 +206,7 @@ class CastingSpec extends FunSpec {
205
206
assert(tc(fromValue) == expectedScalarTensor)
206
207
assert(ct(fromScalarTensor) == expectedValue)
207
208
assert(lct(fromListTensor) == expectedList)
209
+ assert(ct1(fromTensor) == expectedValue)
208
210
}
209
211
210
212
}
You can’t perform that action at this time.
0 commit comments