Incorrect total_param_bytes and model size for models with params in non-leaf layers #366

@RaySteak

Describe the bug
When summary is used to calculate the size of a model that has parameters in nested layers, the reported total_param_bytes (and therefore the model size) is too small.

To Reproduce
Print total_param_bytes from the result of summary for any model that has params in nested layers:

import torch
import torch.nn as nn

from torchinfo import summary

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 5)
    
    def forward(self, x):
        return self.fc.forward(x)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = Layer()
        self.fc = nn.Linear(5, 5)
    
    def forward(self, x):
        x = self.layer.forward(x)
        return self.fc(x)

model = Net()

s = summary(model, input_size=(32, 3), verbose=0)
print(s.total_param_bytes)

Expected behavior
total_param_bytes should be 200, but the printed value is 120. The parameters of model.layer are omitted.
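
The expected figure can be checked by hand. The sketch below recomputes it in plain Python, assuming the parameters use the default float32 dtype (4 bytes per element) and that each nn.Linear holds a weight of shape (out_features, in_features) plus a bias of length out_features. The helper name linear_param_count is made up for illustration and is not part of PyTorch or torchinfo.

```python
BYTES_PER_FLOAT32 = 4  # assumption: parameters use PyTorch's default float32 dtype


def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights (out x in) plus one bias value per output feature."""
    return in_features * out_features + out_features


layer_fc = linear_param_count(3, 5)  # model.layer.fc
net_fc = linear_param_count(5, 5)    # model.fc

expected_bytes = (layer_fc + net_fc) * BYTES_PER_FLOAT32
reported_bytes = net_fc * BYTES_PER_FLOAT32

print(expected_bytes)                   # 200
print(reported_bytes)                   # 120
print(expected_bytes - reported_bytes)  # 80, the bytes held by model.layer.fc
```

The 80 missing bytes match the 20 parameters of model.layer.fc exactly, which is consistent with that layer's parameters being left out of the byte total.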

Desktop (please complete the following information):

  • OS: Windows 11
  • PyTorch 2.6.0
  • Torchinfo 1.8.0

Additional context
The issue seems to stem from the fact that parameter sizes are not counted for non-leaf layers in ModelStatistics:

for layer_info in summary_list:
    if layer_info.is_leaf_layer:
        self.total_mult_adds += layer_info.macs
        if layer_info.num_params > 0:
            # x2 for gradients
            self.total_output_bytes += layer_info.output_bytes * 2
        if layer_info.is_recursive:
            continue
        self.total_params += max(layer_info.num_params, 0)
        self.total_param_bytes += layer_info.param_bytes
        self.trainable_params += max(layer_info.trainable_params, 0)
    else:
        if layer_info.is_recursive:
            continue
        leftover_params = layer_info.leftover_params()
        leftover_trainable_params = layer_info.leftover_trainable_params()
        self.total_params += max(leftover_params, 0)
        self.trainable_params += max(leftover_trainable_params, 0)
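
One reading of the loop above is that the non-leaf branch adds leftover parameters to total_params but never adds anything to total_param_bytes. The following is a minimal, self-contained sketch of that accounting, not torchinfo's actual code: FakeLayerInfo, tally, and the count_leftover_bytes flag are hypothetical stand-ins, the two summary entries are an assumption about what the reproduction produces, and a fixed 4 bytes per parameter is a simplification.

```python
from dataclasses import dataclass


@dataclass
class FakeLayerInfo:
    """Hypothetical stand-in for torchinfo's LayerInfo; not the real class."""
    is_leaf_layer: bool
    num_params: int = 0   # parameters owned by a leaf layer
    param_bytes: int = 0  # bytes of those parameters
    leftover: int = 0     # non-leaf parameters not attributed to any child entry

    def leftover_params(self) -> int:
        return self.leftover


def tally(summary_list, count_leftover_bytes: bool, bytes_per_param: int = 4):
    """Return (total_params, total_param_bytes) for the given entries."""
    total_params = 0
    total_param_bytes = 0
    for info in summary_list:
        if info.is_leaf_layer:
            total_params += info.num_params
            total_param_bytes += info.param_bytes
        else:
            total_params += info.leftover_params()
            if count_leftover_bytes:
                # Assumed fix: also convert leftover parameters to bytes.
                total_param_bytes += info.leftover_params() * bytes_per_param
    return total_params, total_param_bytes


# Assumed summary contents for the reproduction: one non-leaf entry holding
# 20 leftover parameters and one leaf Linear(5, 5) with 30 parameters.
entries = [
    FakeLayerInfo(is_leaf_layer=False, leftover=20),
    FakeLayerInfo(is_leaf_layer=True, num_params=30, param_bytes=120),
]

print(tally(entries, count_leftover_bytes=False))  # (50, 120)
print(tally(entries, count_leftover_bytes=True))   # (50, 200)
```

With the flag off, the parameter count is complete (50) while the byte total stops at 120, which mirrors the reported output. A real fix would presumably need per-parameter element sizes rather than a constant, since models can mix dtypes.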
