Skip to content

Commit c0d3f80

Browse files
Add verbose option to pytorch hub models (#2926)
* Add verbose and update print to logging * Fix positonal param * Revert auto formatting changes * Update hubconf.py Co-authored-by: Glenn Jocher <[email protected]>
1 parent 3665c0f commit c0d3f80

File tree

3 files changed

+41
-38
lines changed

3 files changed

+41
-38
lines changed

hubconf.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616

1717
dependencies = ['torch', 'yaml']
1818
check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop'))
19-
set_logging()
2019

2120

22-
def create(name, pretrained, channels, classes, autoshape):
21+
def create(name, pretrained, channels, classes, autoshape, verbose):
2322
"""Creates a specified YOLOv5 model
2423
2524
Arguments:
@@ -32,6 +31,8 @@ def create(name, pretrained, channels, classes, autoshape):
3231
pytorch model
3332
"""
3433
try:
34+
set_logging(verbose=verbose)
35+
3536
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
3637
model = Model(cfg, channels, classes)
3738
if pretrained:
@@ -55,7 +56,7 @@ def create(name, pretrained, channels, classes, autoshape):
5556
raise Exception(s) from e
5657

5758

58-
def custom(path_or_model='path/to/model.pt', autoshape=True):
59+
def custom(path_or_model='path/to/model.pt', autoshape=True, verbose=True):
5960
"""YOLOv5-custom model https://github.com/ultralytics/yolov5
6061
6162
Arguments (3 options):
@@ -66,6 +67,8 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
6667
Returns:
6768
pytorch model
6869
"""
70+
set_logging(verbose=verbose)
71+
6972
model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
7073
if isinstance(model, dict):
7174
model = model['ema' if model.get('ema') else 'model'] # load model
@@ -79,49 +82,49 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
7982
return hub_model.to(device)
8083

8184

82-
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
85+
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
8386
# YOLOv5-small model https://github.com/ultralytics/yolov5
84-
return create('yolov5s', pretrained, channels, classes, autoshape)
87+
return create('yolov5s', pretrained, channels, classes, autoshape, verbose)
8588

8689

87-
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
90+
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
8891
# YOLOv5-medium model https://github.com/ultralytics/yolov5
89-
return create('yolov5m', pretrained, channels, classes, autoshape)
92+
return create('yolov5m', pretrained, channels, classes, autoshape, verbose)
9093

9194

92-
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
95+
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
9396
# YOLOv5-large model https://github.com/ultralytics/yolov5
94-
return create('yolov5l', pretrained, channels, classes, autoshape)
97+
return create('yolov5l', pretrained, channels, classes, autoshape, verbose)
9598

9699

97-
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
100+
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
98101
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
99-
return create('yolov5x', pretrained, channels, classes, autoshape)
102+
return create('yolov5x', pretrained, channels, classes, autoshape, verbose)
100103

101104

102-
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True):
103-
# YOLOv5-small model https://github.com/ultralytics/yolov5
104-
return create('yolov5s6', pretrained, channels, classes, autoshape)
105+
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
106+
# YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
107+
return create('yolov5s6', pretrained, channels, classes, autoshape, verbose)
105108

106109

107-
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True):
108-
# YOLOv5-medium model https://github.com/ultralytics/yolov5
109-
return create('yolov5m6', pretrained, channels, classes, autoshape)
110+
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
111+
# YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
112+
return create('yolov5m6', pretrained, channels, classes, autoshape, verbose)
110113

111114

112-
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True):
113-
# YOLOv5-large model https://github.com/ultralytics/yolov5
114-
return create('yolov5l6', pretrained, channels, classes, autoshape)
115+
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
116+
# YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
117+
return create('yolov5l6', pretrained, channels, classes, autoshape, verbose)
115118

116119

117-
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True):
118-
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
119-
return create('yolov5x6', pretrained, channels, classes, autoshape)
120+
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
121+
# YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
122+
return create('yolov5x6', pretrained, channels, classes, autoshape, verbose)
120123

121124

122125
if __name__ == '__main__':
123-
model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example
124-
# model = custom(path_or_model='path/to/model.pt') # custom example
126+
model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True) # pretrained
127+
# model = custom(path_or_model='path/to/model.pt') # custom
125128

126129
# Verify inference
127130
import cv2

models/yolo.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
8484
self.yaml['anchors'] = round(anchors) # override yaml value
8585
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
8686
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
87-
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
87+
# logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
8888

8989
# Build strides, anchors
9090
m = self.model[-1] # Detect()
@@ -95,7 +95,7 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
9595
check_anchor_order(m)
9696
self.stride = m.stride
9797
self._initialize_biases() # only run once
98-
# print('Strides: %s' % m.stride.tolist())
98+
# logger.info('Strides: %s' % m.stride.tolist())
9999

100100
# Init weights, biases
101101
initialize_weights(self)
@@ -134,13 +134,13 @@ def forward_once(self, x, profile=False):
134134
for _ in range(10):
135135
_ = m(x)
136136
dt.append((time_synchronized() - t) * 100)
137-
print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
137+
logger.info('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
138138

139139
x = m(x) # run
140140
y.append(x if m.i in self.save else None) # save output
141141

142142
if profile:
143-
print('%.1fms total' % sum(dt))
143+
logger.info('%.1fms total' % sum(dt))
144144
return x
145145

146146
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
@@ -157,15 +157,15 @@ def _print_biases(self):
157157
m = self.model[-1] # Detect() module
158158
for mi in m.m: # from
159159
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
160-
print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
160+
logger.info(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
161161

162162
# def _print_weights(self):
163163
# for m in self.model.modules():
164164
# if type(m) is Bottleneck:
165-
# print('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
165+
# logger.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
166166

167167
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
168-
print('Fusing layers... ')
168+
logger.info('Fusing layers... ')
169169
for m in self.model.modules():
170170
if type(m) is Conv and hasattr(m, 'bn'):
171171
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
@@ -177,19 +177,19 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
177177
def nms(self, mode=True): # add or remove NMS module
178178
present = type(self.model[-1]) is NMS # last layer is NMS
179179
if mode and not present:
180-
print('Adding NMS... ')
180+
logger.info('Adding NMS... ')
181181
m = NMS() # module
182182
m.f = -1 # from
183183
m.i = self.model[-1].i + 1 # index
184184
self.model.add_module(name='%s' % m.i, module=m) # add
185185
self.eval()
186186
elif not mode and present:
187-
print('Removing NMS... ')
187+
logger.info('Removing NMS... ')
188188
self.model = self.model[:-1] # remove
189189
return self
190190

191191
def autoshape(self): # add autoShape module
192-
print('Adding autoShape... ')
192+
logger.info('Adding autoShape... ')
193193
m = autoShape(self) # wrap model
194194
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
195195
return m
@@ -272,6 +272,6 @@ def parse_model(d, ch): # model_dict, input_channels(3)
272272
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
273273
# from torch.utils.tensorboard import SummaryWriter
274274
# tb_writer = SummaryWriter('.')
275-
# print("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
275+
# logger.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
276276
# tb_writer.add_graph(torch.jit.trace(model, img, strict=False), []) # add model graph
277277
# tb_writer.add_image('test', img[0], dataformats='CWH') # add model to tensorboard

utils/general.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
3333

3434

35-
def set_logging(rank=-1):
35+
def set_logging(rank=-1, verbose=True):
3636
logging.basicConfig(
3737
format="%(message)s",
38-
level=logging.INFO if rank in [-1, 0] else logging.WARN)
38+
level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
3939

4040

4141
def init_seeds(seed=0):

0 commit comments

Comments
 (0)