Commit 01fa0ce

Separate nested_list_size function, add some documentation, improve mypy for setuptools (#220)

1 parent c879e2a

File tree

  .github/workflows/test.yml
  profiler.py
  setup.cfg
  setup.py
  torchinfo/layer_info.py
  torchinfo/model_statistics.py

6 files changed: +36 −47 lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 3 deletions

@@ -55,13 +55,12 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install mypy pytest pytest-cov
-        pip install torch==${{ matrix.pytorch-version }} torchvision
-        pip install transformers
+        pip install torch==${{ matrix.pytorch-version }} torchvision transformers
         pip install compressai
     - name: mypy
       if: ${{ matrix.pytorch-version == '1.13' }}
       run: |
-        mypy .
+        mypy --install-types --non-interactive .
     - name: pytest
       if: ${{ matrix.pytorch-version == '1.13' }}
       run: |

profiler.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import random
 
 import torchvision  # type: ignore[import] # pylint: disable=unused-import # noqa
-from tqdm import trange  # type: ignore[import] # pylint: disable=unused-import # noqa
+from tqdm import trange  # pylint: disable=unused-import # noqa
 
 from torchinfo import summary  # pylint: disable=unused-import # noqa
 

setup.cfg

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ torchinfo = py.typed
 [mypy]
 strict = True
 implicit_reexport = True
-show_error_codes = True
 enable_error_code = ignore-without-code
 
 [pylint.main]

setup.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
-from setuptools import setup  # type: ignore[import]
+from setuptools import setup
 
 setup()
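
The three mypy-related changes above work together: the new --install-types
--non-interactive flags let mypy download missing third-party stub packages
during CI (for setuptools that is types-setuptools from typeshed), so the
"type: ignore[import]" suppressions in setup.py and profiler.py become
unnecessary. A minimal sketch of the effect, assuming the relevant stubs
resolve (the tqdm stubs are my presumption, not stated in the diff):

    # After `mypy --install-types --non-interactive .` has installed the
    # missing stub packages, both imports type-check without suppressions:
    from setuptools import setup  # previously: # type: ignore[import]
    from tqdm import trange       # previously: # type: ignore[import]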

torchinfo/layer_info.py

Lines changed: 26 additions & 29 deletions

@@ -93,38 +93,10 @@ def calculate_size(
     Returns the corrected shape of `inputs` and the size of
     a single element in bytes.
     """
-
-    def nested_list_size(
-        inputs: Sequence[Any] | torch.Tensor,
-    ) -> tuple[list[int], int]:
-        """Flattens nested list size."""
-
-        if hasattr(inputs, "tensors"):
-            size, elem_bytes = nested_list_size(inputs.tensors)
-        elif isinstance(inputs, torch.Tensor):
-            size, elem_bytes = list(inputs.size()), inputs.element_size()
-        elif not hasattr(inputs, "__getitem__") or not inputs:
-            size, elem_bytes = [], 0
-        elif isinstance(inputs, dict):
-            size, elem_bytes = nested_list_size(list(inputs.values()))
-        elif (
-            hasattr(inputs, "size")
-            and callable(inputs.size)
-            and hasattr(inputs, "element_size")
-            and callable(inputs.element_size)
-        ):
-            size, elem_bytes = list(inputs.size()), inputs.element_size()
-        elif isinstance(inputs, (list, tuple)):
-            size, elem_bytes = nested_list_size(inputs[0])
-        else:
-            size, elem_bytes = [], 0
-
-        return size, elem_bytes
-
     if inputs is None:
         size, elem_bytes = [], 0
 
-        # pack_padded_seq and pad_packed_seq store feature into data attribute
+    # pack_padded_seq and pad_packed_seq store feature into data attribute
     elif (
         isinstance(inputs, (list, tuple)) and inputs and hasattr(inputs[0], "data")
     ):
@@ -337,6 +309,31 @@ def leftover_trainable_params(self) -> int:
         )
 
 
+def nested_list_size(inputs: Sequence[Any] | torch.Tensor) -> tuple[list[int], int]:
+    """Flattens nested list size."""
+    if hasattr(inputs, "tensors"):
+        size, elem_bytes = nested_list_size(inputs.tensors)
+    elif isinstance(inputs, torch.Tensor):
+        size, elem_bytes = list(inputs.size()), inputs.element_size()
+    elif not hasattr(inputs, "__getitem__") or not inputs:
+        size, elem_bytes = [], 0
+    elif isinstance(inputs, dict):
+        size, elem_bytes = nested_list_size(list(inputs.values()))
+    elif (
+        hasattr(inputs, "size")
+        and callable(inputs.size)
+        and hasattr(inputs, "element_size")
+        and callable(inputs.element_size)
+    ):
+        size, elem_bytes = list(inputs.size()), inputs.element_size()
+    elif isinstance(inputs, (list, tuple)):
+        size, elem_bytes = nested_list_size(inputs[0])
+    else:
+        size, elem_bytes = [], 0
+
+    return size, elem_bytes
+
+
 def prod(num_list: Iterable[int] | torch.Size) -> int:
     result = 1
     if isinstance(num_list, Iterable):
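
With nested_list_size hoisted to module scope, it can be imported and tested
on its own. A quick illustrative sketch of how it traverses nested inputs
(the example values are hypothetical, not part of the commit):

    import torch

    from torchinfo.layer_info import nested_list_size

    t = torch.zeros(2, 3)                # float32, so 4 bytes per element
    print(nested_list_size(t))           # ([2, 3], 4): shape plus element size
    print(nested_list_size([t, t]))      # ([2, 3], 4): lists report their first element
    print(nested_list_size({"out": t}))  # ([2, 3], 4): dicts are traversed via values()
    print(nested_list_size([]))          # ([], 0): empty or unsized inputs count as nothing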

torchinfo/model_statistics.py

Lines changed: 6 additions & 12 deletions

@@ -25,6 +25,8 @@ def __init__(
         self.total_params, self.trainable_params = 0, 0
         self.total_param_bytes, self.total_output_bytes = 0, 0
 
+        # TODO: Figure out why the below functions using max() are ever 0
+        # (they should always be non-negative), and remove the call to max().
         for layer_info in summary_list:
             if layer_info.is_leaf_layer:
                 self.total_mult_adds += layer_info.macs
@@ -33,24 +35,16 @@ def __init__(
                 self.total_output_bytes += layer_info.output_bytes * 2
                 if layer_info.is_recursive:
                     continue
-                self.total_params += (
-                    layer_info.num_params if layer_info.num_params > 0 else 0
-                )
+                self.total_params += max(layer_info.num_params, 0)
                 self.total_param_bytes += layer_info.param_bytes
-                self.trainable_params += (
-                    layer_info.trainable_params
-                    if layer_info.trainable_params > 0
-                    else 0
-                )
+                self.trainable_params += max(layer_info.trainable_params, 0)
             else:
                 if layer_info.is_recursive:
                     continue
                 leftover_params = layer_info.leftover_params()
                 leftover_trainable_params = layer_info.leftover_trainable_params()
-                self.total_params += leftover_params if leftover_params > 0 else 0
-                self.trainable_params += (
-                    leftover_trainable_params if leftover_trainable_params > 0 else 0
-                )
+                self.total_params += max(leftover_params, 0)
+                self.trainable_params += max(leftover_trainable_params, 0)
         self.formatting.set_layer_name_width(summary_list)
 
     def __repr__(self) -> str:
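
The max() rewrites are behavior-preserving: for integer counts, the
expression "x if x > 0 else 0" and max(x, 0) agree on every input, including
the negative values the new TODO comment says should never occur. A quick
sanity check:

    for n in (-5, 0, 42):
        assert (n if n > 0 else 0) == max(n, 0)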
