Skip to content

Commit 419f91c

Browse files
committed
Fix unknown_value=np.nan in OrdinalEncoder
Signed-off-by: xadupre <[email protected]>
1 parent 576bbb1 commit 419f91c

File tree

2 files changed

+64
-8
lines changed

2 files changed

+64
-8
lines changed

skl2onnx/operator_converters/ordinal_encoder.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,20 @@ def convert_sklearn_ordinal_encoder(
2626
dimension_idx = 0
2727

2828
# handle the 'handle_unknown=use_encoded_value' case
29+
use_float = (
30+
False
31+
if ordinal_op.unknown_value is None
32+
else isinstance(ordinal_op.unknown_value, float)
33+
or np.isnan(ordinal_op.unknown_value)
34+
)
2935
default_value = (
30-
None if ordinal_op.handle_unknown == "error" else int(ordinal_op.unknown_value)
36+
None
37+
if ordinal_op.handle_unknown == "error"
38+
else (
39+
float(ordinal_op.unknown_value)
40+
if use_float
41+
else int(ordinal_op.unknown_value)
42+
)
3143
)
3244

3345
for categories in ordinal_op.categories_:
@@ -103,24 +115,28 @@ def convert_sklearn_ordinal_encoder(
103115
)
104116

105117
# hanlde encoded_missing_value
118+
key = "values_floats" if use_float else "values_int64s"
119+
dtype = np.float32 if use_float else np.int64
106120
if not np.isnan(ordinal_op.encoded_missing_value) and (
107121
isinstance(categories[-1], float) and np.isnan(categories[-1])
108122
):
109123
# sklearn always places np.nan as the last entry
110-
# in its cathegories if it was in the training data
124+
# in its categories if it was in the training data
111125
# => we simply add the 'ordinal_op.encoded_missing_value'
112126
# as our last entry in 'values_int64s' if it was in the training data
113127
encoded_missing_value = np.array(
114128
[int(ordinal_op.encoded_missing_value)]
115-
).astype(np.int64)
116-
attrs["values_int64s"] = np.concatenate(
117-
(np.arange(len(categories) - 1).astype(np.int64), encoded_missing_value)
129+
).astype(dtype)
130+
attrs[key] = np.concatenate(
131+
(np.arange(len(categories) - 1).astype(dtype), encoded_missing_value)
118132
)
119133
else:
120-
attrs["values_int64s"] = np.arange(len(categories)).astype(np.int64)
134+
attrs[key] = np.arange(len(categories)).astype(dtype)
121135

122-
if default_value:
123-
attrs["default_int64"] = default_value
136+
if default_value or (
137+
isinstance(default_value, float) and np.isnan(default_value)
138+
):
139+
attrs["default_float" if use_float else "default_int64"] = default_value
124140

125141
result.append(scope.get_unique_variable_name("ordinal_output"))
126142
label_encoder_output = scope.get_unique_variable_name("label_encoder")

tests/test_sklearn_ordinal_encoder.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,46 @@ def test_ordinal_encoder_pipeline_string_int64(self):
379379
)
380380
assert_almost_equal(expected, got[0].ravel())
381381

382+
@unittest.skipIf(
383+
not ordinal_encoder_support(),
384+
reason="OrdinalEncoder was not available before 0.20",
385+
)
386+
def test_model_ordinal_encoder_unknown_value_nan(self):
387+
from onnxruntime import InferenceSession
388+
389+
model = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
390+
data = np.array([["a"], ["b"], ["c"], ["d"]], dtype=np.object_)
391+
data_with_missing_value = np.array(
392+
[["a"], ["b"], ["c"], ["d"], [np.nan], ["e"], [None]], dtype=np.object_
393+
)
394+
395+
model.fit(data)
396+
# 'np.nan','e' and 'None' become 42.
397+
expected = model.transform(data_with_missing_value)
398+
399+
model_onnx = convert_sklearn(
400+
model,
401+
"scikit-learn ordinal encoder",
402+
[("input", StringTensorType([None, 1]))],
403+
target_opset=TARGET_OPSET,
404+
)
405+
self.assertIsNotNone(model_onnx)
406+
dump_data_and_model(
407+
data, model, model_onnx, basename="SklearnOrdinalEncoderUnknownValue"
408+
)
409+
410+
sess = InferenceSession(
411+
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
412+
)
413+
got = sess.run(
414+
None,
415+
{
416+
"input": data_with_missing_value,
417+
},
418+
)
419+
420+
assert_almost_equal(expected.reshape(-1), got[0].reshape(-1))
421+
382422

383423
if __name__ == "__main__":
384424
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)