Skip to content

Commit 8170301

Browse files
committed
merge conflicts
Signed-off-by: xadupre <[email protected]>
2 parents 50e1f67 + 2462f35 commit 8170301

File tree

3 files changed

+118
-4
lines changed

3 files changed

+118
-4
lines changed

CHANGELOGS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Fixes unknown_value=np.nan in OrdinalEncoder
66
[#1198](https://github.com/onnx/sklearn-onnx/issues/1198)
7+
* Enhance OrdinalEncoder conversion to handle infrequent categories
8+
[#1195](https://github.com/onnx/sklearn-onnx/issues/1195)
79

810
## 1.19.1
911

skl2onnx/operator_converters/ordinal_encoder.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,16 @@ def convert_sklearn_ordinal_encoder(
4646
if len(categories) == 0:
4747
continue
4848

49+
if (
50+
hasattr(ordinal_op, "_infrequent_enabled")
51+
and ordinal_op._infrequent_enabled
52+
):
53+
default_to_infrequent_mappings = ordinal_op._default_to_infrequent_mappings[
54+
input_idx
55+
]
56+
else:
57+
default_to_infrequent_mappings = None
58+
4959
current_input = operator.inputs[input_idx]
5060
if current_input.get_second_dimension() == 1:
5161
feature_column = current_input
@@ -127,11 +137,28 @@ def convert_sklearn_ordinal_encoder(
127137
encoded_missing_value = np.array(
128138
[int(ordinal_op.encoded_missing_value)]
129139
).astype(dtype)
130-
attrs[key] = np.concatenate(
131-
(np.arange(len(categories) - 1).astype(dtype), encoded_missing_value)
132-
)
140+
141+
# handle max_categories or min_frequency
142+
if default_to_infrequent_mappings is not None:
143+
attrs[key] = np.concatenate(
144+
(
145+
np.array(default_to_infrequent_mappings, dtype=dtype),
146+
encoded_missing_value,
147+
)
148+
)
149+
else:
150+
attrs[key] = np.concatenate(
151+
(
152+
np.arange(len(categories) - 1).astype(dtype),
153+
encoded_missing_value,
154+
)
155+
)
133156
else:
134-
attrs[key] = np.arange(len(categories)).astype(dtype)
157+
# handle max_categories or min_frequency
158+
if default_to_infrequent_mappings is not None:
159+
attrs[key] = np.array(default_to_infrequent_mappings, dtype=dtype)
160+
else:
161+
attrs[key] = np.arange(len(categories)).astype(dtype)
135162

136163
if default_value or (
137164
isinstance(default_value, float) and np.isnan(default_value)

tests/test_sklearn_ordinal_encoder.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def set_output_support():
4040
return pv.Version(vers) >= pv.Version("1.2")
4141

4242

43+
def max_categories_support():
44+
vers = ".".join(sklearn_version.split(".")[:2])
45+
return pv.Version(vers) >= pv.Version("1.3")
46+
47+
4348
class TestSklearnOrdinalEncoderConverter(unittest.TestCase):
4449
@unittest.skipIf(
4550
not ordinal_encoder_support(),
@@ -379,6 +384,86 @@ def test_ordinal_encoder_pipeline_string_int64(self):
379384
)
380385
assert_almost_equal(expected, got[0].ravel())
381386

387+
@unittest.skipIf(
388+
not max_categories_support(),
389+
reason="OrdinalEncoder supports max_categories and min_frequencey since 1.3",
390+
)
391+
def test_model_ordinal_encoder_max_categories(self):
392+
from onnxruntime import InferenceSession
393+
394+
model = OrdinalEncoder(max_categories=4)
395+
data = np.array(
396+
[["a"], ["b"], ["c"], ["d"], ["a"], ["b"], ["c"], ["e"]], dtype=np.object_
397+
)
398+
399+
expected = model.fit_transform(data)
400+
401+
model_onnx = convert_sklearn(
402+
model,
403+
"scikit-learn ordinal encoder",
404+
[("input", StringTensorType([None, 1]))],
405+
target_opset=TARGET_OPSET,
406+
)
407+
self.assertIsNotNone(model_onnx)
408+
dump_data_and_model(
409+
data,
410+
model,
411+
model_onnx,
412+
basename="SklearnOrdinalEncoderMaxCategories",
413+
)
414+
415+
sess = InferenceSession(
416+
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
417+
)
418+
got = sess.run(
419+
None,
420+
{
421+
"input": data,
422+
},
423+
)
424+
425+
assert_almost_equal(expected.reshape(-1), got[0].reshape(-1))
426+
427+
@unittest.skipIf(
428+
not max_categories_support(),
429+
reason="OrdinalEncoder supports max_categories and min_frequencey since 1.3",
430+
)
431+
def test_model_ordinal_encoder_min_frequency(self):
432+
from onnxruntime import InferenceSession
433+
434+
model = OrdinalEncoder(min_frequency=2)
435+
data = np.array(
436+
[["a"], ["b"], ["c"], ["d"], ["a"], ["b"], ["c"], ["e"]], dtype=np.object_
437+
)
438+
439+
expected = model.fit_transform(data)
440+
441+
model_onnx = convert_sklearn(
442+
model,
443+
"scikit-learn ordinal encoder",
444+
[("input", StringTensorType([None, 1]))],
445+
target_opset=TARGET_OPSET,
446+
)
447+
self.assertIsNotNone(model_onnx)
448+
dump_data_and_model(
449+
data,
450+
model,
451+
model_onnx,
452+
basename="SklearnOrdinalEncoderMinFrequency",
453+
)
454+
455+
sess = InferenceSession(
456+
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
457+
)
458+
got = sess.run(
459+
None,
460+
{
461+
"input": data,
462+
},
463+
)
464+
465+
assert_almost_equal(expected.reshape(-1), got[0].reshape(-1))
466+
382467
@unittest.skipIf(
383468
not ordinal_encoder_support(),
384469
reason="OrdinalEncoder was not available before 0.20",

0 commit comments

Comments
 (0)