# See the License for the specific language governing permissions and
# limitations under the License.
1414
15-
1615from collections import OrderedDict
1716
1817from paddle .distributed .fleet .model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
4645 return _GLOBAL_INDEX_LAYER_FUNC
4746
4847
# Process-wide hook: a callable that, given the pipeline model, returns a
# mapping from structured parameter names (snames) to target tensor names
# (tnames). ``None`` until a user registers one.
_GLOBAL_SNAME_TO_TNAME_FUNC = None


def register_sname_to_tname_func(func):
    """Register the global sname-to-tname mapping function.

    Args:
        func: a callable taking the pipeline model and returning a dict
            mapping structured parameter names to target tensor names.
    """
    global _GLOBAL_SNAME_TO_TNAME_FUNC
    _GLOBAL_SNAME_TO_TNAME_FUNC = func


def has_register_sname_to_tname_func():
    """Return ``True`` when a sname-to-tname function has been registered."""
    # Read-only access: no ``global`` statement needed.
    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None


def get_sname_to_tname_func():
    """Return the registered sname-to-tname function.

    Raises:
        AssertionError: if no function has been registered yet.
    """
    # Raise explicitly instead of using a bare ``assert`` so the check is
    # not stripped under ``python -O``; the exception type and message are
    # unchanged for existing callers.
    if _GLOBAL_SNAME_TO_TNAME_FUNC is None:
        raise AssertionError("sname to tname func is not registered yet")
    return _GLOBAL_SNAME_TO_TNAME_FUNC
66+
4967class LayerNameScope :
5068 """
5169 layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
206224 self ._segments = OrderedDict ()
207225 self ._layer_to_segment = OrderedDict ()
208226 self ._param_to_tname = OrderedDict ()
227+ self ._wname_to_rname = OrderedDict ()
209228
210229 def add_segment (self , start_index , end_index ):
211230 segment = PipeLineSegment (start_index , end_index )
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
218237 segment = self ._layer_to_segment [layer_index ]
219238 segment .add_layer (layer_name , param_names )
220239
def build_name_mapping(self, sname_to_tname=None):
    """Build the parameter-name mapping for every layer in this stage.

    For each parameter of each layer, a renamed tensor name is obtained
    from ``self._rename_mgr`` and the ``(tensor_name, new_name)`` pair is
    recorded in ``self._param_to_tname``. When ``sname_to_tname`` is
    given, any parameter present in it is additionally recorded in
    ``self._wname_to_rname`` so that ``map_name`` returns the externally
    supplied name instead of the generated one.

    Args:
        sname_to_tname: optional dict mapping a parameter name to an
            externally chosen target name; ``None`` disables overrides.
    """
    # Keys of the segment/layer dicts are unused here, so iterate values.
    for segment in self._segments.values():
        for layer in segment.layers.values():
            for param_name, tensor_name in layer.params.items():
                # Map to a new, stage-unique name.
                n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
                # Record an externally requested override, if any.
                if sname_to_tname is not None and param_name in sname_to_tname:
                    self._wname_to_rname[param_name] = sname_to_tname[param_name]
                # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                self._param_to_tname[param_name] = (tensor_name, n_name)
230252
def map_name(self, param_name, t_name):
    """Translate a parameter to its stage-local renamed tensor name.

    An override registered in ``self._wname_to_rname`` (filled by
    ``build_name_mapping``) takes precedence over the generated name.
    """
    assert param_name in self._param_to_tname
    recorded_tname, generated_name = self._param_to_tname[param_name]
    assert recorded_tname == t_name
    # ``dict.get`` returns the stored value whenever the key is present
    # (even if it stores ``None``), so this matches an explicit ``in`` check.
    return self._wname_to_rname.get(param_name, generated_name)
236260
@@ -261,6 +285,11 @@ def __init__(
261285 self ._index_layers ()
262286
263287 stage_segments = self ._segment ()
288+ if has_register_sname_to_tname_func ():
289+ self ._sname_to_tname = get_sname_to_tname_func ()(pp_model )
290+ else :
291+ self ._sname_to_tname = None
292+
264293 for (i , stage_seg ) in enumerate (stage_segments ):
265294 pipe_stage = PipeLineStage ()
266295 self ._stages .append (pipe_stage )
@@ -275,7 +304,7 @@ def __init__(
275304 self ._layer_name_to_stage [layer_name ] = i
276305
277306 for stage in self ._stages :
278- stage .build_name_mapping ()
307+ stage .build_name_mapping (self . _sname_to_tname )
279308
280309 def _index_layers (self ):
281310 for layer_name in self ._param_names_by_layer .keys ():
0 commit comments