@@ -121,9 +121,6 @@ def __init__(self, name_scope=None, dtype="float32"):
         self._forward_pre_hooks = collections.OrderedDict()
         self._forward_post_hooks = collections.OrderedDict()

-        self._parameters_transform_map = {}
-        self._buffers_transform_map = {}
-
         self._casted_by_pure_fp16 = False

         self._state_dict_hooks = collections.OrderedDict()
@@ -1473,24 +1470,14 @@ def _apply(self, func, device, dtype, blocking):
             if param is not None:
                 with no_grad():
                     param_applied = func(param, device, dtype, blocking)
-                    assert param.is_leaf
-                    param_applied.stop_gradient = param.stop_gradient
-                    self._parameters[key] = param_applied

                 if param.grad is not None:
                     with no_grad():
                         grad_applied = func(param._grad_ivar(), device, dtype,
                                             blocking)

-                        grad_applied.stop_gradient = param._grad_ivar(
-                        ).stop_gradient
-                        self._parameters[key]._set_grad_ivar(grad_applied)
-
-                self._parameters_transform_map[id(param)] = [param_applied, key]
-
         for key, buf in self._buffers.items():
             self._buffers[key] = func(buf, device, dtype, blocking)
-            self._buffers_transform_map[id(buf)] = [self._buffers[key], key]

     def to(self, device=None, dtype=None, blocking=None):
         '''
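
The two deletions above are connected: `_apply` used to rebind `self._parameters[key]` to a freshly built tensor and then record `id(old) -> [new, key]` in the transform maps, presumably so other references could be remapped later. The new `transform` below instead shares the converted storage back into the original object and returns that same object, so none of this bookkeeping is needed. A minimal plain-Python sketch of the pattern; `FakeTensor` and `transform_in_place` are hypothetical names for illustration, not Paddle API:

    class FakeTensor:
        def __init__(self, data):
            self.data = data

    def transform_in_place(t, func):
        new_data = func(t.data)  # build the converted storage
        t.data = new_data        # hand it back to the original object
        return t                 # callers keep the same Python object

    params = {'weight': FakeTensor([1.0, 2.0])}
    ident = id(params['weight'])
    params['weight'] = transform_in_place(params['weight'],
                                          lambda d: [x * 2.0 for x in d])
    assert id(params['weight']) == ident  # no id-keyed transform map needed
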
@@ -1568,22 +1555,59 @@ def transform(t, device, dtype, blocking):
             if dtype is None:
                 dtype = t.dtype

-            new_t = t._copy_to(device, blocking)
-            if isinstance(t, framework.ParamBase):
-                if dtype is not None and dtype != t.dtype:
+            # 1. For the GPU place, first determine whether memory is sufficient for the allocation:
+            if t.place.is_gpu_place():
+                gpu_memory_available = core.gpu_memory_available()
+                # For GPU, the minimum memory allocation unit is 256 bytes.
+                if type(dtype) is str:
+                    size_dtype = core.size_of_dtype(
+                        convert_np_dtype_to_dtype_(dtype))
+                else:
+                    size_dtype = core.size_of_dtype(dtype)
+                # Note(zhangbo): Paddle's GPU minimum memory allocation unit is 256 bytes; waiting_alloc_memory computes the memory space occupied by 't'.
+                # The coefficient 1.2 avoids the OOM that may occur in the critical state where memory is only just enough.
+                waiting_alloc_memory = (
+                    (t.numel().numpy()[0] * size_dtype) / 256 + 1) * 256 * 1.2
+                if gpu_memory_available < waiting_alloc_memory:
+                    # Copy the param / Tensor to CPU.
+                    t_used = t._copy_to(paddle.CPUPlace(),
+                                        blocking)  # k-v type will error
+                    # Release the memory held by t.
+                    t.value().get_tensor()._clear()
+                else:
+                    t_used = t
+            else:
+                t_used = t
+
+            # 2. Cast the param / Tensor to the target dtype.
+            if dtype is not None and dtype != t_used.dtype:
+                if isinstance(t_used, framework.ParamBase):
+                    from paddle.fluid.layer_helper import LayerHelper
+                    helper = LayerHelper("cast", **locals())
+                    t_casted = helper.create_variable_for_type_inference(
+                        dtype=dtype)
                     framework._dygraph_tracer().trace_op(
                         type='cast',
-                        inputs={'X': new_t},
-                        outputs={'Out': new_t},
+                        inputs={'X': t_used},
+                        outputs={'Out': t_casted},
                         attrs={
-                            'in_dtype': t.dtype,
+                            'in_dtype': t_used.dtype,
                             'out_dtype': convert_np_dtype_to_dtype_(dtype)
                         })
+                else:
+                    t_casted = t_used.cast(dtype=dtype)
             else:
-                if dtype is not None and dtype != t.dtype:
-                    new_t = new_t.cast(dtype=dtype)
+                t_casted = t_used
+
+            # 3. Copy the casted param / Tensor to the target device.
+            new_t = t_casted._copy_to(device, blocking)
+
+            # 4. Share the new Tensor back into the original param / Tensor.
+            dst_tensor = t.value().get_tensor()
+            src_tensor = new_t.value().get_tensor()
+            dst_tensor._share_data_with(src_tensor)

-            return new_t
+            return t

         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning)
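
The step-1 estimate can be read in isolation. Below is a standalone sketch of the arithmetic; `estimate_alloc_bytes` is a hypothetical helper mirroring the expression in the diff. Note that because `/` is Python 3 true division, `(raw / 256 + 1) * 256` equals exactly `raw + 256`, i.e. one extra 256-byte allocation unit of slack before the 1.2 headroom factor is applied:

    def estimate_alloc_bytes(numel, dtype_size, unit=256, headroom=1.2):
        raw_bytes = numel * dtype_size
        padded = (raw_bytes / unit + 1) * unit  # true division: raw_bytes + unit
        return padded * headroom                # 1.2x safety margin against OOM

    # 1000 float32 elements: 4000 raw bytes -> 4256 padded -> 5107.2 estimated
    print(estimate_alloc_bytes(1000, 4))
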
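End to end, the change turns `Layer.to` into an in-place conversion: `transform` returns the original `t` after `_share_data_with`, so parameter identity survives the call. A sketch of the expected behavior once this diff is applied, using a standard `paddle.nn.Linear` purely for illustration:

    import paddle

    net = paddle.nn.Linear(2, 2)
    w_id = id(net.weight)

    net.to(device='cpu', dtype='float64', blocking=True)

    # transform() shares the converted storage back into the original object,
    # so references taken before the call remain valid.
    assert id(net.weight) == w_id
    assert net.weight.dtype == paddle.float64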