Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,12 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install mypy pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }} torchvision
pip install transformers
pip install torch==${{ matrix.pytorch-version }} torchvision transformers
pip install compressai
- name: mypy
if: ${{ matrix.pytorch-version == '1.13' }}
run: |
mypy .
mypy --install-types --non-interactive .
- name: pytest
if: ${{ matrix.pytorch-version == '1.13' }}
run: |
Expand Down
2 changes: 1 addition & 1 deletion profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random

import torchvision # type: ignore[import] # pylint: disable=unused-import # noqa
from tqdm import trange # type: ignore[import] # pylint: disable=unused-import # noqa
from tqdm import trange # pylint: disable=unused-import # noqa

from torchinfo import summary # pylint: disable=unused-import # noqa

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ torchinfo = py.typed
[mypy]
strict = True
implicit_reexport = True
show_error_codes = True
enable_error_code = ignore-without-code

[pylint.main]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from setuptools import setup # type: ignore[import]
from setuptools import setup

setup()
55 changes: 26 additions & 29 deletions torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,38 +93,10 @@ def calculate_size(
Returns the corrected shape of `inputs` and the size of
a single element in bytes.
"""

def nested_list_size(
inputs: Sequence[Any] | torch.Tensor,
) -> tuple[list[int], int]:
"""Flattens nested list size."""

if hasattr(inputs, "tensors"):
size, elem_bytes = nested_list_size(inputs.tensors)
elif isinstance(inputs, torch.Tensor):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif not hasattr(inputs, "__getitem__") or not inputs:
size, elem_bytes = [], 0
elif isinstance(inputs, dict):
size, elem_bytes = nested_list_size(list(inputs.values()))
elif (
hasattr(inputs, "size")
and callable(inputs.size)
and hasattr(inputs, "element_size")
and callable(inputs.element_size)
):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif isinstance(inputs, (list, tuple)):
size, elem_bytes = nested_list_size(inputs[0])
else:
size, elem_bytes = [], 0

return size, elem_bytes

if inputs is None:
size, elem_bytes = [], 0

# pack_padded_seq and pad_packed_seq store feature into data attribute
# pack_padded_seq and pad_packed_seq store feature into data attribute
elif (
isinstance(inputs, (list, tuple)) and inputs and hasattr(inputs[0], "data")
):
Expand Down Expand Up @@ -337,6 +309,31 @@ def leftover_trainable_params(self) -> int:
)


def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], int]:
"""Flattens nested list size."""
if hasattr(inputs, "tensors"):
size, elem_bytes = nested_list_size(inputs.tensors)
elif isinstance(inputs, torch.Tensor):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif not hasattr(inputs, "__getitem__") or not inputs:
size, elem_bytes = [], 0
elif isinstance(inputs, dict):
size, elem_bytes = nested_list_size(list(inputs.values()))
elif (
hasattr(inputs, "size")
and callable(inputs.size)
and hasattr(inputs, "element_size")
and callable(inputs.element_size)
):
size, elem_bytes = list(inputs.size()), inputs.element_size()
elif isinstance(inputs, (list, tuple)):
size, elem_bytes = nested_list_size(inputs[0])
else:
size, elem_bytes = [], 0

return size, elem_bytes


def prod(num_list: Iterable[int] | torch.Size) -> int:
result = 1
if isinstance(num_list, Iterable):
Expand Down
18 changes: 6 additions & 12 deletions torchinfo/model_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
self.total_params, self.trainable_params = 0, 0
self.total_param_bytes, self.total_output_bytes = 0, 0

# TODO: Figure out why the below functions using max() are ever 0
# (they should always be non-negative), and remove the call to max().
for layer_info in summary_list:
if layer_info.is_leaf_layer:
self.total_mult_adds += layer_info.macs
Expand All @@ -33,24 +35,16 @@ def __init__(
self.total_output_bytes += layer_info.output_bytes * 2
if layer_info.is_recursive:
continue
self.total_params += (
layer_info.num_params if layer_info.num_params > 0 else 0
)
self.total_params += max(layer_info.num_params, 0)
self.total_param_bytes += layer_info.param_bytes
self.trainable_params += (
layer_info.trainable_params
if layer_info.trainable_params > 0
else 0
)
self.trainable_params += max(layer_info.trainable_params, 0)
else:
if layer_info.is_recursive:
continue
leftover_params = layer_info.leftover_params()
leftover_trainable_params = layer_info.leftover_trainable_params()
self.total_params += leftover_params if leftover_params > 0 else 0
self.trainable_params += (
leftover_trainable_params if leftover_trainable_params > 0 else 0
)
self.total_params += max(leftover_params, 0)
self.trainable_params += max(leftover_trainable_params, 0)
self.formatting.set_layer_name_width(summary_list)

def __repr__(self) -> str:
Expand Down