@@ -40,6 +40,11 @@ def set_output_support():
40
40
return pv .Version (vers ) >= pv .Version ("1.2" )
41
41
42
42
43
def max_categories_support():
    """Return True when ``sklearn_version`` is at least release 1.3.

    Only the first two dot-separated components of the version string are
    compared, so anything after "major.minor" is ignored. The tests in this
    file use the result to skip ``OrdinalEncoder`` cases that rely on
    ``max_categories`` / ``min_frequency``.
    """
    major_minor = sklearn_version.split(".")[:2]
    current = pv.Version(".".join(major_minor))
    return current >= pv.Version("1.3")
46
+
47
+
43
48
class TestSklearnOrdinalEncoderConverter (unittest .TestCase ):
44
49
@unittest .skipIf (
45
50
not ordinal_encoder_support (),
@@ -379,6 +384,86 @@ def test_ordinal_encoder_pipeline_string_int64(self):
379
384
)
380
385
assert_almost_equal (expected , got [0 ].ravel ())
381
386
387
@unittest.skipIf(
    not max_categories_support(),
    reason="OrdinalEncoder supports max_categories and min_frequency since 1.3",
)
def test_model_ordinal_encoder_max_categories(self):
    """Check the ONNX conversion of ``OrdinalEncoder(max_categories=4)``.

    The encoder is fitted on a single string column, converted with
    ``convert_sklearn``, and the output of the ONNX model run through
    onnxruntime is compared with scikit-learn's ``fit_transform`` result.
    """
    from onnxruntime import InferenceSession

    model = OrdinalEncoder(max_categories=4)
    # Five distinct values ("a", "b", "c" twice; "d", "e" once): one more
    # than max_categories, so the limit is actually exercised.
    data = np.array(
        [["a"], ["b"], ["c"], ["d"], ["a"], ["b"], ["c"], ["e"]], dtype=np.object_
    )

    # Reference output; this call also fits the model before conversion.
    expected = model.fit_transform(data)

    model_onnx = convert_sklearn(
        model,
        "scikit-learn ordinal encoder",
        [("input", StringTensorType([None, 1]))],
        target_opset=TARGET_OPSET,
    )
    self.assertIsNotNone(model_onnx)
    dump_data_and_model(
        data,
        model,
        model_onnx,
        basename="SklearnOrdinalEncoderMaxCategories",
    )

    sess = InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    got = sess.run(
        None,
        {
            "input": data,
        },
    )

    # Flatten both sides so only the encoded values are compared.
    assert_almost_equal(expected.reshape(-1), got[0].reshape(-1))
426
+
427
@unittest.skipIf(
    not max_categories_support(),
    reason="OrdinalEncoder supports max_categories and min_frequency since 1.3",
)
def test_model_ordinal_encoder_min_frequency(self):
    """Check the ONNX conversion of ``OrdinalEncoder(min_frequency=2)``.

    The encoder is fitted on a single string column, converted with
    ``convert_sklearn``, and the output of the ONNX model run through
    onnxruntime is compared with scikit-learn's ``fit_transform`` result.
    """
    from onnxruntime import InferenceSession

    model = OrdinalEncoder(min_frequency=2)
    # "a", "b", "c" occur twice; "d" and "e" occur once, i.e. fewer times
    # than min_frequency, so the threshold is actually exercised.
    data = np.array(
        [["a"], ["b"], ["c"], ["d"], ["a"], ["b"], ["c"], ["e"]], dtype=np.object_
    )

    # Reference output; this call also fits the model before conversion.
    expected = model.fit_transform(data)

    model_onnx = convert_sklearn(
        model,
        "scikit-learn ordinal encoder",
        [("input", StringTensorType([None, 1]))],
        target_opset=TARGET_OPSET,
    )
    self.assertIsNotNone(model_onnx)
    dump_data_and_model(
        data,
        model,
        model_onnx,
        basename="SklearnOrdinalEncoderMinFrequency",
    )

    sess = InferenceSession(
        model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    got = sess.run(
        None,
        {
            "input": data,
        },
    )

    # Flatten both sides so only the encoded values are compared.
    assert_almost_equal(expected.reshape(-1), got[0].reshape(-1))
466
+
382
467
@unittest .skipIf (
383
468
not ordinal_encoder_support (),
384
469
reason = "OrdinalEncoder was not available before 0.20" ,
0 commit comments