Skip to content

Conversation

@DrMicrobit
Copy link

@DrMicrobit DrMicrobit commented Jul 9, 2025

This PR resolves #371: M2 Mac: Runtime error in training of model after call to torchinfo.summary()

For torchinfo 1.8 on a Mac with M2 chip, the following code resulted in a runtime error:

device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)

with the error message

RuntimeError: Tensor for argument weight is on cpu but expected on mps

The same code ran fine on Linux with a Nvidia card.

Cause of bug:
In torchinfo.py, the function get_device() seems to be focused on recognising only CUDA as accelerator, whereas other platforms may have different accelerators. E.g., M-chip Macs have "mps".
This apparently leads to torchinfo pushing the model to the "cpu" when device= was not given in the call to summary(), which then leads to a runtime error during model training (or evaluation) when the data is on the accelerator and the model (or parts of it) are on the CPU.

Bug fix:
I have create a PR that should fix the bug for any accelerator recognised by PyTorch.

New behaviour of get_device():
Unchanged:

  • If input_data is given, the device should not be changed (to allow for multi-device models, etc.)

Changed:

  • Otherwise gets device of first parameter of model and returns it,
  • otherwise returns current accelerator if it is available,
  • otherwise returns cpu.

Old version failed to recognise non-cuda accelerators,
which led to bugs when torchinfo.summary() was called
without "device=" parameter on, e.g., Macs with M-chips,
where the accelerator is "mps".

New version:
- returns device of first parameter of model if present
- else queries torch for an available accelerator and returns that
- else returns "cpu"
Had left "Any" data type hint from my testing code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

M2 Mac: Runtime error in training of model after call to torchinfo.summary()

1 participant