Status: Open
Labels: bug (Unexpected behaviour that should be corrected)
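🐞 Describing the bug
Converting a model that contains torch.nn.BatchNorm3d with coremltools and running it crashes the Core ML process: the Python process aborts with an LLVM error instead of returning a prediction.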
To Reproduce
import coremltools as ct
import numpy as np
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.BatchNorm3d(3)  # 3 channels, input layout (N, C, D, H, W)

    def forward(self, x):
        return self.norm(x)


model = Model()
inputs = (
    torch.randn(1, 3, 4, 4, 4),
)

# The module is still in training mode here, so this call normalizes with batch
# statistics and updates the running mean/variance used by the eval-mode export below.
eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

ep = torch.export.export(model.eval(), inputs)
ep = ep.run_decompositions({})
mlmodel = ct.convert(ep)

coreml_inputs = mlmodel.get_spec().description.input
coreml_outputs = mlmodel.get_spec().description.output
# Feed the random input as float32 so its values are not truncated to integers.
predict_inputs = {
    str(ct_in.name): pt_in.detach().cpu().numpy().astype(np.float32)
    for ct_in, pt_in in zip(coreml_inputs, inputs)
}
out = mlmodel.predict(predict_inputs)
print("CoreML", out)
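Once prediction works, the Core ML result can be checked against the inference-mode formula that BatchNorm3d applies after model.eval(): every element is normalized with the running statistics of its own channel, y = gamma[c] * (x - mean[c]) / sqrt(var[c] + eps) + beta[c]. The pure-Python sketch below computes this on a flattened, row-major (N, C, D, H, W) buffer. The function name batch_norm_3d_reference is hypothetical and is not part of PyTorch or coremltools.

```python
import math
from typing import List, Sequence


def batch_norm_3d_reference(
    x_flat: Sequence[float],
    shape: Sequence[int],
    mean: Sequence[float],
    var: Sequence[float],
    gamma: Sequence[float],
    beta: Sequence[float],
    eps: float = 1e-5,
) -> List[float]:
    """Inference-mode batch norm over a row-major (N, C, D, H, W) buffer.

    Hypothetical helper for illustration: the per-channel statistics behave
    as if broadcast with shape (1, C, 1, 1, 1).
    """
    if len(shape) != 5:
        raise ValueError("expected a 5-D shape (N, C, D, H, W)")
    channels = shape[1]
    if not (len(mean) == len(var) == len(gamma) == len(beta) == channels):
        raise ValueError("per-channel parameters must have length C")
    spatial = math.prod(shape[2:])  # D * H * W elements per channel slice
    if len(x_flat) != shape[0] * channels * spatial:
        raise ValueError("x_flat does not match shape")

    out = []
    for i, value in enumerate(x_flat):
        c = (i // spatial) % channels  # channel index of element i
        scale = gamma[c] / math.sqrt(var[c] + eps)
        out.append((value - mean[c]) * scale + beta[c])
    return out
```

In the reproduction above, the statistics could be taken from model.norm.running_mean, model.norm.running_var, model.norm.weight, model.norm.bias and model.norm.eps (each converted with .tolist() where needed), and the input from inputs[0].flatten().tolist().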
Output of the reproduction script:
loc("tensor<fp16, [1, 3, 4, 4, 4]> _native_batch_norm_legit_no_training_cast_fp16 = batch_norm(beta = tensor<fp16, [3]>([0, 0, 0]), epsilon = fp16(1.00135803e-05), gamma = tensor<fp16, [3]>([1, 1, 1]), mean = tensor<fp16, [3]>([-0.012588501, 0.0046005249, 0.016494751]), variance = tensor<fp16, [3]>([1.00292969, 1.00195312, 1.01855469]), x = x_to_fp16)[milId = uint64(1), name = string(\22_native_batch_norm_legit_no_training_cast_fp16\22)]; - /private/var/folders/lw/phxpy6k10ll388xs18hyq1cr0000gn/T/tmp63_bi7tb.mlmodelc/model.mil":12:12): error: output type 'tensor<1x3x4x4x4xf16>' and mean type 'tensor<1x0x1x1x601354336xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
zsh: abort python test.py
/opt/miniconda3/envs/op-et/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
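The error says the rank-5 input x (1x3x4x4x4) and the rank-5 mean tensor the runtime ended up with (1x0x1x1x601354336) are not broadcast compatible. Assuming NumPy-style broadcasting rules (dimensions are aligned from the trailing end, and each pair must be equal or contain a 1), a per-channel mean of 3 values would be expected to appear as (1, 3, 1, 1, 1); the reported dimensions 0 and 601354336 do not correspond to any dimension of the model. The small sketch below applies that rule to both shapes; the helper name broadcast_shape is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return the NumPy-style broadcast shape of a and b, or None if incompatible."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Align from the trailing dimension; missing leading dims count as 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == 1:
            result.append(db)
        elif db == 1 or da == db:
            result.append(da)
        else:
            return None  # e.g. 4 vs 601354336
    return tuple(reversed(result))


x_shape = (1, 3, 4, 4, 4)
print(broadcast_shape(x_shape, (1, 3, 1, 1, 1)))          # → (1, 3, 4, 4, 4)
print(broadcast_shape(x_shape, (1, 0, 1, 1, 601354336)))  # → None (shape from the error)
print(broadcast_shape(x_shape, (3,)))                     # → None (raw length-C vector)
```

The last line shows why a length-3 statistics vector cannot be broadcast against a channel-first rank-5 tensor as-is: it has to be laid out as (1, 3, 1, 1, 1) (or (3, 1, 1, 1)) so that it lines up with the channel axis.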
System environment:
- coremltools version: 8.3
- OS: macOS 15