@@ -122,16 +122,16 @@ def train(opt, device):
122
122
for p in model .parameters ():
123
123
p .requires_grad = True # for training
124
124
model = model .to (device )
125
- names = trainloader .dataset .classes # class names
126
- model .names = names # attach class names
127
125
128
126
# Info
129
127
if RANK in {- 1 , 0 }:
128
+ model .names = trainloader .dataset .classes # attach class names
129
+ model .transforms = testloader .dataset .torch_transforms # attach inference transforms
130
130
model_info (model )
131
131
if opt .verbose :
132
132
LOGGER .info (model )
133
133
images , labels = next (iter (trainloader ))
134
- file = imshow_cls (images [:25 ], labels [:25 ], names = names , f = save_dir / 'train_images.jpg' )
134
+ file = imshow_cls (images [:25 ], labels [:25 ], names = model . names , f = save_dir / 'train_images.jpg' )
135
135
logger .log_images (file , name = 'Train Examples' )
136
136
logger .log_graph (model , imgsz ) # log model
137
137
@@ -254,8 +254,8 @@ def train(opt, device):
254
254
255
255
# Plot examples
256
256
images , labels = (x [:25 ] for x in next (iter (testloader ))) # first 25 images and labels
257
- pred = torch .max (ema .ema (( images . half () if cuda else images . float ()) .to (device )), 1 )[1 ]
258
- file = imshow_cls (images , labels , pred , names , verbose = False , f = save_dir / 'test_images.jpg' )
257
+ pred = torch .max (ema .ema (images .to (device )), 1 )[1 ]
258
+ file = imshow_cls (images , labels , pred , model . names , verbose = False , f = save_dir / 'test_images.jpg' )
259
259
260
260
# Log results
261
261
meta = {"epochs" : epochs , "top1_acc" : best_fitness , "date" : datetime .now ().isoformat ()}
0 commit comments