@@ -94,21 +94,22 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
94
94
model = Ensemble ()
95
95
for w in weights if isinstance (weights , list ) else [weights ]:
96
96
ckpt = torch .load (attempt_download (w ), map_location = map_location ) # load
97
- if fuse :
98
- model .append (ckpt ['ema' if ckpt .get ('ema' ) else 'model' ].float ().fuse ().eval ()) # FP32 model
99
- else :
100
- model .append (ckpt ['ema' if ckpt .get ('ema' ) else 'model' ].float ().eval ()) # without layer fuse
97
+ ckpt = (ckpt ['ema' ] or ckpt ['model' ]).float () # FP32 model
98
+ model .append (ckpt .fuse ().eval () if fuse else ckpt .eval ()) # fused or un-fused model in eval mode
101
99
102
100
# Compatibility updates
103
101
for m in model .modules ():
104
- if type (m ) in [nn .Hardswish , nn .LeakyReLU , nn .ReLU , nn .ReLU6 , nn .SiLU , Detect , Model ]:
105
- m .inplace = inplace # pytorch 1.7.0 compatibility
106
- if type (m ) is Detect :
102
+ t = type (m )
103
+ if t in (nn .Hardswish , nn .LeakyReLU , nn .ReLU , nn .ReLU6 , nn .SiLU , Detect , Model ):
104
+ m .inplace = inplace # torch 1.7.0 compatibility
105
+ if t is Detect :
107
106
if not isinstance (m .anchor_grid , list ): # new Detect Layer compatibility
108
107
delattr (m , 'anchor_grid' )
109
108
setattr (m , 'anchor_grid' , [torch .zeros (1 )] * m .nl )
110
- elif type (m ) is Conv :
111
- m ._non_persistent_buffers_set = set () # pytorch 1.6.0 compatibility
109
+ elif t is nn .Upsample :
110
+ m .recompute_scale_factor = None # torch 1.11.0 compatibility
111
+ elif t is Conv :
112
+ m ._non_persistent_buffers_set = set () # torch 1.6.0 compatibility
112
113
113
114
if len (model ) == 1 :
114
115
return model [- 1 ] # return model
0 commit comments