
Commit 61406a5

Fix tensor devices for DARTS Trial (#2273)

* Update architect.py
  [email protected]
  Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
* Update run_trial.py
  [email protected]
  Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
* Update architect.py
  [email protected]
  Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>

---------

Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>

1 parent a2f3fca · commit 61406a5

File tree: 2 files changed, +14 −6 lines
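For context on the bug this commit fixes: the legacy constructor torch.FloatTensor(...) always allocates on the CPU, so when the model's weights live on a GPU, the virtual-step update mixes CPU and CUDA tensors and PyTorch raises a device-mismatch error. A minimal repro sketch follows; it is not part of the commit, the names are illustrative, and xi is assumed to be a one-element list of the learning rate, as in the trial code.

import torch

# Illustrative repro, assuming a CUDA device is available.
if torch.cuda.is_available():
    w = torch.randn(3, device="cuda")   # a weight tensor living on the GPU
    xi = [0.025]                        # virtual-step learning rate (one-element list)
    factor = torch.FloatTensor(xi)      # legacy constructor: always a CPU tensor
    w - factor * w                      # RuntimeError: tensors on different devices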

examples/v1beta1/trial-images/darts-cnn-cifar10/architect.py

Lines changed: 13 additions & 5 deletions
@@ -20,11 +20,12 @@ class Architect():
     """" Architect controls architecture of cell by computing gradients of alphas
     """
 
-    def __init__(self, model, w_momentum, w_weight_decay):
+    def __init__(self, model, w_momentum, w_weight_decay, device):
         self.model = model
         self.v_model = copy.deepcopy(model)
         self.w_momentum = w_momentum
         self.w_weight_decay = w_weight_decay
+        self.device = device
 
     def virtual_step(self, train_x, train_y, xi, w_optim):
         """
@@ -43,17 +44,21 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
         # Forward and calculate loss
         # Loss for train with w. L_train(w)
         loss = self.model.loss(train_x, train_y)
+
         # Compute gradient
         gradients = torch.autograd.grad(loss, self.model.getWeights())
-
+
         # Do virtual step (Update gradient)
         # Below operations do not need gradient tracking
         with torch.no_grad():
             # dict key is not the value, but the pointer. So original network weight have to
             # be iterated also.
             for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                 m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
-                vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
+                if(self.device == 'cuda'):
+                    vw.copy_(w - torch.cuda.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
+                elif(self.device == 'cpu'):
+                    vw.copy_(w - torch.FloatTensor(xi) * (m + g + self.w_weight_decay * w))
 
         # Sync alphas
         for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
@@ -71,7 +76,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Calculate unrolled loss
         # Loss for validation with w'. L_valid(w')
         loss = self.v_model.loss(valid_x, valid_y)
-
+
         # Calculate gradient
         v_alphas = tuple(self.v_model.getAlphas())
         v_weights = tuple(self.v_model.getWeights())
@@ -85,7 +90,10 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
         # Update final gradient = dalpha - xi * hessian
         with torch.no_grad():
             for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
-                alpha.grad = da - torch.FloatTensor(xi) * h
+                if(self.device == 'cuda'):
+                    alpha.grad = da - torch.cuda.FloatTensor(xi) * h
+                elif(self.device == 'cpu'):
+                    alpha.grad = da - torch.cpu.FloatTensor(xi) * h
 
     def compute_hessian(self, dws, train_x, train_y):
         """

examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def main():
         num_epochs,
         eta_min=w_lr_min)
 
-    architect = Architect(model, w_momentum, w_weight_decay)
+    architect = Architect(model, w_momentum, w_weight_decay, device)
 
     # Start training
     best_top1 = 0.
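
run_trial.py now threads device through to the Architect. The value is created earlier in main(), outside the lines shown here; since architect.py compares it against the strings 'cuda' and 'cpu', it is presumably a plain string along these lines (an assumption, not part of the diff):

import torch

# Assumed device selection (not shown in this diff): a plain string,
# matching the string comparisons in architect.py.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ...model, w_momentum, and w_weight_decay set up as elsewhere in main()...
architect = Architect(model, w_momentum, w_weight_decay, device)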
