Skip to content

Commit 99939ef

Browse files
authored
Avoid concatenation if not needed (onnx#1110)
* Avoid concatenation if not needed Signed-off-by: Xavier Dupre <[email protected]> * change Signed-off-by: Xavier Dupre <[email protected]> --------- Signed-off-by: Xavier Dupre <[email protected]>
1 parent 90e3d86 commit 99939ef

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

CHANGELOGS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 1.18.0
44

5+
* Converter for OneHotEncoder does not add a concat operator if not needed,
6+
[#1110](https://github.com/onnx/sklearn-onnx/pull/1110)
57
* Function ``to_onnx`` now forces the main opset to be equal to the
68
value speficied by the user (parameter ``target_opset``),
79
[#1109](https://github.com/onnx/sklearn-onnx/pull/1109)

skl2onnx/operator_converters/one_hot_encoder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,11 @@ def convert_sklearn_one_hot_encoder(
238238
result.append(ohe_output)
239239
categories_len += len(categories)
240240

241-
concat_result_name = scope.get_unique_variable_name("concat_result")
242-
apply_concat(scope, result, concat_result_name, container, axis=-1)
241+
if len(result) == 1:
242+
concat_result_name = result[0]
243+
else:
244+
concat_result_name = scope.get_unique_variable_name("concat_result")
245+
apply_concat(scope, result, concat_result_name, container, axis=-1)
243246

244247
reshape_input = concat_result_name
245248
if np.issubdtype(ohe_op.dtype, np.signedinteger):

0 commit comments

Comments
 (0)