Skip to content

Shape mismatch RuntimeError when summarizing model with dynamic expansion based on input shape #364

@noah-beepboop

Description

@noah-beepboop

The code dynamically changes shapes of tensors by referencing from the shapes of other tensors. The torchinfo tracer must not correctly track the shapes over this dynamic operation. I do this in the forward pass of GenerateNoiseTensor. I ran this in a jupyter notebook by the way through google colab

Here is my code:


class ConvBlock(nn.Module):
def init(self, input, output, blocktype):
super().init()
self.blocktype = blocktype
self.conv1 = nn.Conv2d(input, output, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(output, output, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(output)
self.relu = nn.ReLU()
self.bn2 = nn.BatchNorm2d(output)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.noise_injection = GenerateNoiseTensor(256)

def forward(self, image, timestep):
image = self.conv1(image)
image = self.bn1(image)
image = self.relu(image)
image = self.conv2(image)
image = self.bn2(image)
image = self.relu(image)
image += self.noise_injection(image.shape, timestep)

  if self.blocktype == "encoder":
    image = self.pool(image)
  elif self.blocktype == "decoder":
    image = self.up(image)

  return image

Ok so this is kinda goofy but I will explain

So you take the timestep which we will call t

Then you put it through this -> timestep_vector = sin(ax), sin(bx), cos(ax), cos(bx)

It lets u process u turn the 1D timestep into a 2D vector to put thru a NN

class SinusoidalEmbedding(nn.Module):
def init(self, dim):
super().init()
self.dim = dim # How big do you want the output vector to be? (256)

def forward(self, timestep):

timestep = timestep.view(1)
half_dim = self.dim // 2
frequency_factor = math.log(10000) / (half_dim - 1) # Comes up with a frequency factor for the sinusoidal functions (also use 10,000 bc it's standard whatever whatever)
frequencies = torch.exp(torch.arange(half_dim, device=device) * -frequency_factor) # Make the different frequencies according to the frequency factor (idk how this really works)
bx = timestep[:, None] * frequencies[None, :] # Calculate the timestep * frequency
embeddings = torch.cat((torch.sin(bx), torch.cos(bx)), dim=-1) # Processes thru the sin and cos functions
return embeddings

class GenerateNoiseTensor(nn.Module):
def init(self, dim):
super().init()
self.time_embed = SinusoidalEmbedding(dim)
self.dim = dim

self.processing_model = nn.Sequential(
    nn.Linear(dim, dim),
    nn.ReLU(),
    nn.Linear(dim, dim)
)

def forward(self, feature_shape, timestep):
time_embed = self.time_embed(timestep)
processed_embed = self.processing_model(time_embed)
output_noise_tensor = processed_embed.view(-1, self.dim, 1, 1) # [batch, dim, 1, 1]
output_noise_tensor = output_noise_tensor.expand(-1, -1, feature_shape[2], feature_shape[3])
return output_noise_tensor

OK so, stuff we have to specify: feature shape and timestep

class UNetModel(nn.Module):
def init(self): # define all the crap
super().init()
self.en1 = ConvBlock(3, 64, "encoder")
self.en2 = ConvBlock(64, 128, "encoder")
self.en3 = ConvBlock(128, 256, "encoder")
self.en4 = ConvBlock(256, 512, "encoder")
self.en5 = ConvBlock(512, 1024, "encoder")

self.bottleneck = ConvBlock(1024, 1024, "bottleneck")

self.de1 = ConvBlock(1024 + 1024, 512, "decoder")
self.de2 = ConvBlock(512 + 512, 256, "decoder")
self.de3 = ConvBlock(256 + 256, 128, "decoder")
self.de4 = ConvBlock(128 + 128, 64, "decoder")
self.de5 = ConvBlock(64 + 64, 3, "decoder")
self.final_conv = nn.Conv2d(3, 3, kernel_size=1) # Kernel size one bc that's what your supposed to do

def forward(self, image, timestep): # do all the crap
en1_out = self.en1(image, timestep)
en2_out = self.en2(en1_out, timestep)
en3_out = self.en3(en2_out, timestep)
en4_out = self.en4(en3_out, timestep)
en5_out = self.en5(en4_out, timestep)

bottleneck_out = self.bottleneck(en5_out)

de1_out = self.de1(torch.cat([bottleneck_out, en5_out], dim=1), timestep) # skip connection stuff :3
de2_out = self.de2(torch.cat([de1_out, en4_out], dim=1), timestep)
de3_out = self.de3(torch.cat([de2_out, en3_out], dim=1), timestep)
de4_out = self.de4(torch.cat([de3_out, en2_out], dim=1), timestep)
de5_out = self.de5(torch.cat([de4_out, en1_out], dim=1), timestep)
output = self.final_conv(de5_out)

from torchinfo import summary
import torch
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNetModel().to(device)

batch_size = 1
dummy_image = torch.randn(batch_size, 3, 512, 512, device=device)

dummy_timestep = torch.randint(0, 1000, (batch_size, 1), device=device).float()
summary(model, input_data=(dummy_image, dummy_timestep), device=device)


Here is the error message:


RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
294 if isinstance(x, (list, tuple)):
--> 295 _ = model(*x, **kwargs)
296 elif isinstance(x, dict):

10 frames
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1844 try:
-> 1845 return inner()
1846 except Exception:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
1792
-> 1793 result = forward_call(*args, **kwargs)
1794 if _global_forward_hooks or self._forward_hooks:

in forward(self, image, timestep)
95 def forward(self, image, timestep): # do all the crap
---> 96 en1_out = self.en1(image, timestep)
97 en2_out = self.en2(en1_out, timestep)

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1844 try:
-> 1845 return inner()
1846 except Exception:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
1792
-> 1793 result = forward_call(*args, **kwargs)
1794 if _global_forward_hooks or self._forward_hooks:

in forward(self, image, timestep)
26 image = self.relu(image)
---> 27 image += self.noise_injection(image.shape, timestep)
28

RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 1

The above exception was the direct cause of the following exception:

RuntimeError Traceback (most recent call last)
in <cell line: 0>()
23 # Use torchinfo.summary with input_data
24 # Pass the dummy tensors as a tuple to input_data
---> 25 summary(model, input_data=(dummy_image, dummy_timestep), device=device)

/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
221 input_data, input_size, batch_dim, device, dtypes
222 )
--> 223 summary_list = forward_pass(
224 model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
225 )

/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
302 except Exception as e:
303 executed_layers = [layer for layer in summary_list if layer.executed]
--> 304 raise RuntimeError(
305 "Failed to run torchinfo. See above stack traces for more details. "
306 f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Conv2d: 2, BatchNorm2d: 2, ReLU: 2, Conv2d: 2, BatchNorm2d: 2, ReLU: 2, GenerateNoiseTensor: 2, SinusoidalEmbedding: 3, Sequential: 3, Linear: 4, ReLU: 4, Linear: 4]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions