Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cpu/nn/batch_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ class BatchNorm : public OpKernel {
const TensorShape& x_shape = X->Shape();
Tensor* Y = p_op_kernel_context->Output(0, x_shape);

// X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec:
// "The op also accepts single dimension input of size N in which case C is assumed to be 1"
const auto& dims_vec = x_shape.GetDims();
const size_t N = onnxruntime::narrow<size_t>(dims_vec[0]);
const size_t C = onnxruntime::narrow<size_t>(dims_vec[1]); // assume NCHW as per the spec
const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow<size_t>(dims_vec[1]);

// calculate sample_size (per individual channel)
size_t sample_size = 1;
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class BatchNormHelper {
// NHWC dependent shape: X
// All other shapes are assumed to be in NCHW layout?
const auto& x_dims = X->Shape().GetDims();
if (x_dims.size() < 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1");
}

// If x_dims size < 2, num_channels defaults to 1.
int64_t num_channels;
Expand Down
Loading