@@ -121,9 +121,6 @@ def __init__(self, name_scope=None, dtype="float32"):
         self._forward_pre_hooks = collections.OrderedDict()
         self._forward_post_hooks = collections.OrderedDict()
 
-        self._parameters_transform_map = {}
-        self._buffers_transform_map = {}
-
         self._casted_by_pure_fp16 = False
 
         self._state_dict_hooks = collections.OrderedDict()
@@ -1473,24 +1470,14 @@ def _apply(self, func, device, dtype, blocking):
             if param is not None:
                 with no_grad():
                     param_applied = func(param, device, dtype, blocking)
-                    assert param.is_leaf
-                    param_applied.stop_gradient = param.stop_gradient
-                    self._parameters[key] = param_applied
 
                 if param.grad is not None:
                     with no_grad():
                         grad_applied = func(param._grad_ivar(), device, dtype,
                                             blocking)
 
-                        grad_applied.stop_gradient = param._grad_ivar(
-                        ).stop_gradient
-                        self._parameters[key]._set_grad_ivar(grad_applied)
-
-            self._parameters_transform_map[id(param)] = [param_applied, key]
-
         for key, buf in self._buffers.items():
             self._buffers[key] = func(buf, device, dtype, blocking)
-            self._buffers_transform_map[id(buf)] = [self._buffers[key], key]
 
     def to(self, device=None, dtype=None, blocking=None):
         '''
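The bookkeeping removed in this hunk is safe to drop because the rewritten `transform` in the next hunk shares the converted storage back into the original tensor (step 4 below), so each parameter is updated in place and keeps its identity. As a rough illustration only (not the Paddle source), the resulting `_apply` flow looks like this, assuming `func` behaves like that in-place `transform`:

```python
import paddle

def _apply_sketch(layer, func, device, dtype, blocking):
    # Sketch of Layer._apply after this change: func converts each tensor
    # in place (by sharing the converted storage back into it), so there is
    # no need to reassign layer._parameters[key] or track transform maps.
    for key, param in layer._parameters.items():
        if param is not None:
            with paddle.no_grad():
                func(param, device, dtype, blocking)  # param updated in place
            if param.grad is not None:
                with paddle.no_grad():
                    # the gradient tensor is converted the same way
                    func(param._grad_ivar(), device, dtype, blocking)
    for key, buf in layer._buffers.items():
        layer._buffers[key] = func(buf, device, dtype, blocking)
```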
@@ -1568,24 +1555,54 @@ def transform(t, device, dtype, blocking):
             if dtype is None:
                 dtype = t.dtype
 
-            new_t = t._copy_to(device, blocking)
-            if isinstance(t, framework.ParamBase):
-                if dtype is not None and dtype != t.dtype:
-                    framework._dygraph_tracer().trace_op(
-                        type='cast',
-                        inputs={'X': new_t},
-                        outputs={'Out': new_t},
-                        attrs={
-                            'in_dtype': t.dtype,
-                            'out_dtype': convert_np_dtype_to_dtype_(dtype)
-                        })
+            if type(dtype) is str:
+                dtype = convert_np_dtype_to_dtype_(dtype)
+
+            # 1. gpu place needs to determine whether the memory is sufficient for allocation:
+            if t.place.is_gpu_place():
+                # for gpu, the minimum memory allocation unit is 256 bytes.
+                size_dtype = core.size_of_dtype(dtype)
+                # Note(zhangbo): Paddle's GPU minimum memory allocation unit is 256 bytes; waiting_alloc_memory will compute the memory space occupied by 't'.
+                # The coefficient 1.2 is used to avoid the OOM that may occur in this critical state when the memory is only just enough.
+                waiting_alloc_memory = (
+                    (np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
+                gpu_memory_available = core.gpu_memory_available()
+                if gpu_memory_available < waiting_alloc_memory:
+                    # Copy the param / Tensor to cpu
+                    t_used = t._copy_to(paddle.CPUPlace(),
+                                        blocking)  # k-v type will error
+                    # Release the memory of t
+                    t.value().get_tensor()._clear()
+                else:
+                    t_used = t
+            else:
+                t_used = t
+
+            # 2. cast the param / Tensor to dtype
+            if dtype is not None and dtype != t_used.dtype:
+                with paddle.fluid.framework._dygraph_place_guard(
+                        place=t_used.place):
+                    t_casted = t_used.cast(dtype=dtype)
+            else:
+                t_casted = t_used
+
+            # 3. copy the casted cpu param / Tensor to the target device
+            if device is not None and not t_casted.place._equals(device):
+                new_t = t_casted._copy_to(device, blocking)
             else:
-                if dtype is not None and dtype != t.dtype:
-                    new_t = new_t.cast(dtype=dtype)
+                new_t = t_casted
+
+            # 4. share the new Tensor with the origin param / Tensor
+            dst_tensor = t.value().get_tensor()
+            src_tensor = new_t.value().get_tensor()
+            dst_tensor._share_data_with(src_tensor)
+
+            return t
 
-            return new_t
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=UserWarning)
+            self._apply(transform, device, dtype, blocking)
 
-        self._apply(transform, device, dtype, blocking)
         self._dtype = dtype
 
     # [aliases] Compatible with old method names
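For reference, the check in step 1 rounds the tensor's byte size up to the GPU's 256-byte allocation unit and adds a 20% safety margin before comparing it with the free GPU memory. Below is a worked example of that arithmetic in plain Python; `estimate_alloc_bytes` and the example shape are hypothetical, only the formula comes from the patch:

```python
import numpy as np

def estimate_alloc_bytes(shape, size_dtype):
    # Same formula as waiting_alloc_memory above: round up to the next
    # 256-byte allocation unit, then multiply by 1.2 so the check does not
    # pass when free memory is only just enough and an OOM could still occur.
    return ((np.prod(shape) * size_dtype) / 256 + 1) * 256 * 1.2

# A float32 (4-byte) parameter of shape [1024, 1024] occupies 4,194,304 bytes;
# the check requires roughly 1.2x that to convert on the GPU, otherwise the
# tensor is first staged through the CPU and its GPU memory is released.
print(estimate_alloc_bytes([1024, 1024], 4))  # 5033472.0
```

Because step 4 then shares the converted storage back into the original tensor and `transform` returns `t`, a call such as `layer.to(device=..., dtype='float16')` converts parameters in place instead of replacing the `ParamBase` objects, which is why the reassignment and `_parameters_transform_map` bookkeeping removed above are no longer needed.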