
Bugfix/get device #370


Open
wants to merge 2 commits into main
Conversation

@DrMicrobit commented Jul 9, 2025

This PR resolves #371: M2 Mac: Runtime error in training of model after call to torchinfo.summary()

For torchinfo 1.8 on a Mac with an M2 chip, the following code resulted in a runtime error:

import torch
from torch import nn
from torchinfo import summary

batch_size = 64  # any batch size

device = torch.accelerator.current_accelerator()
model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10)).to(device)
summary(model, input_size=(batch_size, 3, 32, 32))
...
out = model(data)

with the error message

RuntimeError: Tensor for argument weight is on cpu but expected on mps

The same code ran fine on Linux with an Nvidia card.

Cause of bug:
In torchinfo.py, the function get_device() recognises only CUDA as an accelerator, whereas other platforms may expose different accelerators; M-chip Macs, for example, use "mps".
As a result, torchinfo pushes the model to the CPU when device= is not passed to summary(), which then causes a runtime error during model training (or evaluation): the data sits on the accelerator while the model (or parts of it) sits on the CPU.
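A minimal sketch of the kind of CUDA-only fallback described above (a hypothetical reconstruction for illustration, not the exact old torchinfo code):

```python
import torch

def old_get_device_sketch() -> torch.device:
    # CUDA-only check: anything that is not CUDA (e.g. Apple's "mps")
    # silently falls back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# On an M2 Mac, torch.cuda.is_available() is False, so this returns "cpu"
# even though the "mps" accelerator is available.
print(old_get_device_sketch())
```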

Bug fix:
This PR should fix the bug for any accelerator recognised by PyTorch.

New behaviour of get_device():
Unchanged:

  • If input_data is given, the device should not be changed (to allow for multi-device models, etc.)

Changed:

  • Otherwise, returns the device of the model's first parameter, if any;
  • otherwise, returns the current accelerator, if one is available;
  • otherwise, returns cpu.
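The priority order above can be sketched as follows (a simplified illustration of the proposed behaviour, not the PR's exact code; the function name get_device_sketch is made up, and the getattr guard is only there because torch.accelerator requires a recent PyTorch):

```python
import torch
from torch import nn

def get_device_sketch(model: nn.Module) -> torch.device:
    # 1) Prefer the device of the model's first parameter, if the model has any.
    try:
        return next(model.parameters()).device
    except StopIteration:
        pass
    # 2) Otherwise fall back to the current accelerator (cuda, mps, xpu, ...),
    #    when this PyTorch build exposes torch.accelerator and one is available.
    accel = getattr(torch, "accelerator", None)
    if accel is not None and accel.is_available():
        return accel.current_accelerator()
    # 3) Otherwise default to the CPU.
    return torch.device("cpu")

model = nn.Sequential(nn.Flatten(), nn.Linear(3072, 10))
print(get_device_sketch(model))  # device of the model's first parameter
```

Because the model's own parameters take precedence, passing input_data or building the model on a specific device still wins over the accelerator fallback.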

Old version failed to recognise non-cuda accelerators,
which led to bugs when torchinfo.summary() was called
without "device=" parameter on, e.g., Macs with M-chips,
where the accelerator is "mps".

New version:
- returns device of first parameter of model if present
- else queries torch for an available accelerator and returns that
- else returns "cpu"
Removed an "Any" type hint left over from my testing code.
Successfully merging this pull request may close these issues.

M2 Mac: Runtime error in training of model after call to torchinfo.summary()