Description
The code dynamically changes the shapes of tensors by referencing the shapes of other tensors, and the torchinfo tracer does not seem to track shapes correctly across this dynamic operation. I do this in the forward pass of GenerateNoiseTensor. For context, I ran this in a Jupyter notebook on Google Colab.
Here is my code:
```python
import math
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, input, output, blocktype):
        super().__init__()
        self.blocktype = blocktype
        self.conv1 = nn.Conv2d(input, output, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(output, output, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(output)
        self.relu = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(output)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.noise_injection = GenerateNoiseTensor(256)

    def forward(self, image, timestep):
        image = self.conv1(image)
        image = self.bn1(image)
        image = self.relu(image)
        image = self.conv2(image)
        image = self.bn2(image)
        image = self.relu(image)
        image += self.noise_injection(image.shape, timestep)
        if self.blocktype == "encoder":
            image = self.pool(image)
        elif self.blocktype == "decoder":
            image = self.up(image)
        return image
```
Ok so this is kinda goofy, but I will explain. You take the timestep, which we will call t, and put it through this:

`timestep_vector = [sin(a·t), sin(b·t), …, cos(a·t), cos(b·t), …]`

This turns the scalar timestep into a vector that you can put through a NN.
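For example, with two made-up frequencies a and b (so dim = 4), a tiny hand-worked sketch of the formula:

```python
import math

t = 5.0            # the timestep
a, b = 1.0, 0.1    # two example frequencies (made-up values)
timestep_vector = [math.sin(a * t), math.sin(b * t),
                   math.cos(a * t), math.cos(b * t)]
# -> a 4-dimensional embedding of the scalar timestep
```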
```python
class SinusoidalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # how big you want the output vector to be (256)

    def forward(self, timestep):
        timestep = timestep.view(1)
        half_dim = self.dim // 2
        # Frequency factor for the sinusoids; 10,000 is the standard base from the Transformer paper
        frequency_factor = math.log(10000) / (half_dim - 1)
        # Geometric progression of frequencies derived from that factor
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -frequency_factor)
        bx = timestep[:, None] * frequencies[None, :]  # timestep * each frequency
        embeddings = torch.cat((torch.sin(bx), torch.cos(bx)), dim=-1)  # pass through sin and cos
        return embeddings
```
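A quick shape check of the embedding (a sketch; it assumes `device` is defined as in the driver snippet further down):

```python
embed = SinusoidalEmbedding(256)
t = torch.tensor([7.0], device=device)  # a made-up timestep
print(embed(t).shape)  # torch.Size([1, 256])
```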
```python
class GenerateNoiseTensor(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.time_embed = SinusoidalEmbedding(dim)
        self.dim = dim
        self.processing_model = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, feature_shape, timestep):
        time_embed = self.time_embed(timestep)
        processed_embed = self.processing_model(time_embed)
        output_noise_tensor = processed_embed.view(-1, self.dim, 1, 1)  # [batch, dim, 1, 1]
        # Broadcast across the spatial dimensions of the feature map
        output_noise_tensor = output_noise_tensor.expand(-1, -1, feature_shape[2], feature_shape[3])
        return output_noise_tensor
```
OK so, the stuff we have to specify here is the feature shape and the timestep.
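For instance (a minimal sketch; the 32×32 spatial size is made up, and `device` comes from the driver snippet further down):

```python
noise_gen = GenerateNoiseTensor(256).to(device)
t = torch.tensor([3.0], device=device)  # a made-up timestep
noise = noise_gen(torch.Size([1, 256, 32, 32]), t)
print(noise.shape)  # torch.Size([1, 256, 32, 32]); the channel count is always self.dim
```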
```python
class UNetModel(nn.Module):
    def __init__(self):  # define all the crap
        super().__init__()
        self.en1 = ConvBlock(3, 64, "encoder")
        self.en2 = ConvBlock(64, 128, "encoder")
        self.en3 = ConvBlock(128, 256, "encoder")
        self.en4 = ConvBlock(256, 512, "encoder")
        self.en5 = ConvBlock(512, 1024, "encoder")
        self.bottleneck = ConvBlock(1024, 1024, "bottleneck")
        self.de1 = ConvBlock(1024 + 1024, 512, "decoder")
        self.de2 = ConvBlock(512 + 512, 256, "decoder")
        self.de3 = ConvBlock(256 + 256, 128, "decoder")
        self.de4 = ConvBlock(128 + 128, 64, "decoder")
        self.de5 = ConvBlock(64 + 64, 3, "decoder")
        self.final_conv = nn.Conv2d(3, 3, kernel_size=1)  # kernel size one because that's what you're supposed to do

    def forward(self, image, timestep):  # do all the crap
        en1_out = self.en1(image, timestep)
        en2_out = self.en2(en1_out, timestep)
        en3_out = self.en3(en2_out, timestep)
        en4_out = self.en4(en3_out, timestep)
        en5_out = self.en5(en4_out, timestep)
        bottleneck_out = self.bottleneck(en5_out, timestep)
        de1_out = self.de1(torch.cat([bottleneck_out, en5_out], dim=1), timestep)  # skip connection stuff :3
        de2_out = self.de2(torch.cat([de1_out, en4_out], dim=1), timestep)
        de3_out = self.de3(torch.cat([de2_out, en3_out], dim=1), timestep)
        de4_out = self.de4(torch.cat([de3_out, en2_out], dim=1), timestep)
        de5_out = self.de5(torch.cat([de4_out, en1_out], dim=1), timestep)
        output = self.final_conv(de5_out)
        return output
```
```python
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetModel().to(device)

batch_size = 1
dummy_image = torch.randn(batch_size, 3, 512, 512, device=device)
dummy_timestep = torch.randint(0, 1000, (batch_size, 1), device=device).float()

summary(model, input_data=(dummy_image, dummy_timestep), device=device)
```
Here is the error message:
```
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
294 if isinstance(x, (list, tuple)):
--> 295 _ = model(*x, **kwargs)
296 elif isinstance(x, dict):
10 frames
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1844 try:
-> 1845 return inner()
1846 except Exception:
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
1792
-> 1793 result = forward_call(*args, **kwargs)
1794 if _global_forward_hooks or self._forward_hooks:
in forward(self, image, timestep)
95 def forward(self, image, timestep): # do all the crap
---> 96 en1_out = self.en1(image, timestep)
97 en2_out = self.en2(en1_out, timestep)
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1844 try:
-> 1845 return inner()
1846 except Exception:
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in inner()
1792
-> 1793 result = forward_call(*args, **kwargs)
1794 if _global_forward_hooks or self._forward_hooks:
in forward(self, image, timestep)
26 image = self.relu(image)
---> 27 image += self.noise_injection(image.shape, timestep)
28
RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 1
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
in <cell line: 0>()
23 # Use torchinfo.summary with input_data
24 # Pass the dummy tensors as a tuple to input_data
---> 25 summary(model, input_data=(dummy_image, dummy_timestep), device=device)
/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
221 input_data, input_size, batch_dim, device, dtypes
222 )
--> 223 summary_list = forward_pass(
224 model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
225 )
/usr/local/lib/python3.11/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
302 except Exception as e:
303 executed_layers = [layer for layer in summary_list if layer.executed]
--> 304 raise RuntimeError(
305 "Failed to run torchinfo. See above stack traces for more details. "
306 f"Executed layers up to: {executed_layers}"
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Conv2d: 2, BatchNorm2d: 2, ReLU: 2, Conv2d: 2, BatchNorm2d: 2, ReLU: 2, GenerateNoiseTensor: 2, SinusoidalEmbedding: 3, Sequential: 3, Linear: 4, ReLU: 4, Linear: 4]
```
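To check whether the failure is torchinfo-specific, the forward pass can also be run directly without the tracer (a minimal sketch reusing the dummy inputs from above):

```python
# Quick check: does the failure also happen outside the torchinfo tracer?
with torch.no_grad():
    out = model(dummy_image, dummy_timestep)
# If this raises the same RuntimeError, the shape mismatch is independent of torchinfo.
```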