@@ -20,11 +20,12 @@ class Architect():
2020    """ Architect controls architecture of cell by computing gradients of alphas 
2121    """ 
2222
23-     def  __init__ (self , model , w_momentum , w_weight_decay ):
23+     def  __init__ (self , model , w_momentum , w_weight_decay ,  device ):
2424        self .model  =  model 
2525        self .v_model  =  copy .deepcopy (model )
2626        self .w_momentum  =  w_momentum 
2727        self .w_weight_decay  =  w_weight_decay 
28+         self .device  =  device 
2829
2930    def  virtual_step (self , train_x , train_y , xi , w_optim ):
3031        """ 
@@ -43,17 +44,21 @@ def virtual_step(self, train_x, train_y, xi, w_optim):
4344        # Forward and calculate loss 
4445        # Loss for train with w. L_train(w) 
4546        loss  =  self .model .loss (train_x , train_y )
47+ 
4648        # Compute gradient 
4749        gradients  =  torch .autograd .grad (loss , self .model .getWeights ())
48- 
50+          
4951        # Do virtual step (Update gradient) 
5052        # Below operations do not need gradient tracking 
5153        with  torch .no_grad ():
5254            # dict key is not the value, but the pointer. So original network weight have to 
5355            # be iterated also. 
5456            for  w , vw , g  in  zip (self .model .getWeights (), self .v_model .getWeights (), gradients ):
5557                m  =  w_optim .state [w ].get ("momentum_buffer" , 0. ) *  self .w_momentum 
56-                 vw .copy_ (w  -  torch .FloatTensor (xi ) *  (m  +  g  +  self .w_weight_decay  *  w ))
58+                 if (self .device  ==  'cuda' ):
59+                     vw .copy_ (w  -  torch .cuda .FloatTensor (xi ) *  (m  +  g  +  self .w_weight_decay  *  w ))
60+                 elif (self .device  ==  'cpu' ):
61+                     vw .copy_ (w  -  torch .FloatTensor (xi ) *  (m  +  g  +  self .w_weight_decay  *  w ))
5762
5863            # Sync alphas 
5964            for  a , va  in  zip (self .model .getAlphas (), self .v_model .getAlphas ()):
@@ -71,7 +76,7 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
7176        # Calculate unrolled loss 
7277        # Loss for validation with w'. L_valid(w') 
7378        loss  =  self .v_model .loss (valid_x , valid_y )
74- 
79+          
7580        # Calculate gradient 
7681        v_alphas  =  tuple (self .v_model .getAlphas ())
7782        v_weights  =  tuple (self .v_model .getWeights ())
@@ -85,7 +90,10 @@ def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
8590        # Update final gradient = dalpha - xi * hessian 
8691        with  torch .no_grad ():
8792            for  alpha , da , h  in  zip (self .model .getAlphas (), dalpha , hessian ):
88-                 alpha .grad  =  da  -  torch .FloatTensor (xi ) *  h 
93+                 if (self .device  ==  'cuda' ):
94+                     alpha .grad  =  da  -  torch .cuda .FloatTensor (xi ) *  h 
95+                 elif (self .device  ==  'cpu' ):
96+                     alpha .grad  =  da  -  torch .cpu .FloatTensor (xi ) *  h 
8997
9098    def  compute_hessian (self , dws , train_x , train_y ):
9199        """ 
0 commit comments