 from collections import namedtuple
 import enum
 
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ def __init__(self, net: Network, weights: NetworkWeights):
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ def calc_scale(self):
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:
@@ -155,5 +181,10 @@ def calc_updown(self, target):
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
 
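For context, below is a minimal standalone sketch of the dispatch pattern this diff adds to NetworkModule: capture the wrapped module's functional op plus its shape-related kwargs once, then push the LoRA delta weight through that same op and add the result to the original output y. The DeltaApplier class and the demo tensors are illustrative names made up for this sketch, not code from the commit, and only the Conv2d and Linear branches are shown.

# Minimal sketch, assuming plain PyTorch only; names here are hypothetical.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeltaApplier:
    def __init__(self, sd_module):
        # Capture the functional op and the shape-related kwargs of the wrapped module,
        # mirroring what the diff does in NetworkModule.__init__ for Conv2d and Linear.
        self.ops = None
        self.extra_kwargs = {}
        if isinstance(sd_module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {'stride': sd_module.stride, 'padding': sd_module.padding}
        elif isinstance(sd_module, nn.Linear):
            self.ops = F.linear

    def forward(self, x, y, updown, ex_bias=None):
        # y is the unmodified module output; the delta weight goes through the same
        # functional op and is added on top, as in the new NetworkModule.forward.
        if self.ops is None:
            raise NotImplementedError()
        return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)


conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 3, 16, 16)
y = conv(x)
applier = DeltaApplier(conv)

# A zero delta leaves the output unchanged.
assert torch.allclose(applier.forward(x, y, torch.zeros_like(conv.weight)), y)

# A nonzero delta is equivalent to running the conv with (weight + delta).
delta = 0.01 * torch.randn_like(conv.weight)
ref = F.conv2d(x, conv.weight + delta, conv.bias, stride=conv.stride, padding=conv.padding)
assert torch.allclose(applier.forward(x, y, delta), ref, atol=1e-5)

This works because conv2d and linear are linear in their weight argument, so applying the op to the delta alone yields exactly the change in the module's output; the same identity holds for the affine weight/bias of layer_norm and group_norm, which is why the diff only needs to stash stride/padding or normalized_shape/num_groups/eps in extra_kwargs rather than re-deriving anything at forward time.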