Skip to content

Commit 2ba2392

Browse files
glenn-jocherClay Januhowski
authored andcommitted
Attach transforms to model (ultralytics#9028)
* Attach transforms to model Signed-off-by: Glenn Jocher <[email protected]> * Update val.py Signed-off-by: Glenn Jocher <[email protected]> * Update train.py Signed-off-by: Glenn Jocher <[email protected]> Signed-off-by: Glenn Jocher <[email protected]>
1 parent 9c1878c commit 2ba2392

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

classify/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,16 @@ def train(opt, device):
122122
for p in model.parameters():
123123
p.requires_grad = True # for training
124124
model = model.to(device)
125-
names = trainloader.dataset.classes # class names
126-
model.names = names # attach class names
127125

128126
# Info
129127
if RANK in {-1, 0}:
128+
model.names = trainloader.dataset.classes # attach class names
129+
model.transforms = testloader.dataset.torch_transforms # attach inference transforms
130130
model_info(model)
131131
if opt.verbose:
132132
LOGGER.info(model)
133133
images, labels = next(iter(trainloader))
134-
file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg')
134+
file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
135135
logger.log_images(file, name='Train Examples')
136136
logger.log_graph(model, imgsz) # log model
137137

@@ -254,8 +254,8 @@ def train(opt, device):
254254

255255
# Plot examples
256256
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
257-
pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1]
258-
file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg')
257+
pred = torch.max(ema.ema(images.to(device)), 1)[1]
258+
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
259259

260260
# Log results
261261
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}

classify/val.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def run(
3939
project=ROOT / 'runs/val-cls', # save to project/name
4040
name='exp', # save to project/name
4141
exist_ok=False, # existing project/name ok, do not increment
42-
half=True, # use FP16 half-precision inference
42+
half=False, # use FP16 half-precision inference
4343
dnn=False, # use OpenCV DNN for ONNX inference
4444
model=None,
4545
dataloader=None,
@@ -124,7 +124,6 @@ def run(
124124
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
125125
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
126126

127-
model.float() # for training
128127
return top1, top5, loss
129128

130129

0 commit comments

Comments
 (0)