@@ -92,21 +92,36 @@ classification task of your choice.)
9292 )
9393 return model
9494
95- def get_dataloader (batch_size = 256 , num_workers = 8 , split = ' train' ):
96-
97- transforms = torchvision.transforms.Compose(
98- [torchvision.transforms.RandomHorizontalFlip(),
99- torchvision.transforms.RandomAffine(0 ),
100- torchvision.transforms.ToTensor(),
101- torchvision.transforms.Normalize((0.4914 , 0.4822 , 0.4465 ), (0.2023 , 0.1994 , 0.201 ))])
102-
95+ def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
96+     if augment:
97+         transforms = torchvision.transforms.Compose(
98+             [torchvision.transforms.RandomHorizontalFlip(),
99+              torchvision.transforms.RandomAffine(0),
100+              torchvision.transforms.ToTensor(),
101+              torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
102+                                               (0.2023, 0.1994, 0.201))])
103+     else:
104+         transforms = torchvision.transforms.Compose([
105+             torchvision.transforms.ToTensor(),
106+             torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
107+                                              (0.2023, 0.1994, 0.201))])
108+
103109     is_train = (split == 'train')
104-     dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', download=True, train=is_train, transform=transforms)
105-     loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)
106-
110+     dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
111+                                            download=True,
112+                                            train=is_train,
113+                                            transform=transforms)
114+
115+     loader = torch.utils.data.DataLoader(dataset=dataset,
116+                                          shuffle=shuffle,
117+                                          batch_size=batch_size,
118+                                          num_workers=num_workers)
119+
107120     return loader
108121
109- def train(model, loader, lr=0.4, epochs=24, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
122+ def train(model, loader, lr=0.4, epochs=24, momentum=0.9,
123+           weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, model_id=0):
124+
110125     opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
111126     iters_per_epoch = len(loader)
112127     # Cyclic LR with single triangle
@@ -118,9 +133,8 @@ classification task of your choice.)
118133     loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)
119134
120135     for ep in range(epochs):
121-         model_count = 0
122136         for it, (ims, labs) in enumerate(loader):
123-             ims = ims.float().cuda()
137+             ims = ims.cuda()
124138             labs = labs.cuda()
125139             opt.zero_grad(set_to_none=True)
126140             with autocast():
@@ -131,15 +145,19 @@ classification task of your choice.)
131145             scaler.step(opt)
132146             scaler.update()
133147             scheduler.step()
148+         if ep in [12, 15, 18, 21, 23]:
149+             torch.save(model.state_dict(), f'./checkpoints/sd_{model_id}_epoch_{ep}.pt')
150+
151+     return model
134152
135153 os.makedirs('./checkpoints', exist_ok=True)
154+ loader_for_training = get_dataloader(batch_size=512, split='train', shuffle=True)
136155
137- for i in tqdm(range(3), desc='Training models..'):
156+ # you can modify the for loop below to train more models
157+ for i in tqdm(range(1), desc='Training models..'):
138158     model = construct_rn9().to(memory_format=torch.channels_last).cuda()
139-     loader_train = get_dataloader(batch_size=512, split='train')
140-     train(model, loader_train)
159+     model = train(model, loader_for_training, model_id=i)
141160
142-     torch.save(model.state_dict(), f'./checkpoints/sd_{i}.pt')
143161
144162 .. raw :: html
145163
@@ -311,4 +329,4 @@ The final line above returns :code:`TRAK` scores as a :code:`numpy.array` from t
311329
312330That's it!
313331Once you have your model(s) and your data, just a few API-calls to TRAK
314- let's you compute data attribution scores.
332+ lets you compute data attribution scores.
0 commit comments