Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR resolves #371: M2 Mac: Runtime error in training of model after call to torchinfo.summary()
For torchinfo 1.8 on a Mac with M2 chip, the following code resulted in a runtime error:
with the error message
The same code ran fine on Linux with a Nvidia card.
Cause of bug:
In torchinfo.py, the function
get_device()
seems to be focused on recognising only CUDA as accelerator, whereas other platforms may have different accelerators. E.g., M-chip Macs have "mps".This apparently leads to torchinfo pushing the model to the "cpu" when
device=
was not given in the call tosummary()
, which then leads to a runtime error during model training (or evaluation) when the data is on the accelerator and the model (or parts of it) are on the CPU.Bug fix:
I have create a PR that should fix the bug for any accelerator recognised by PyTorch.
New behaviour of get_device():
Unchanged:
Changed: