@@ -64,6 +64,7 @@ def build_comp_desc_from_op(op):
6464 var = get_var_with_recursion (var_name , op .block , op .block .program )
6565 shape = var .shape
6666 var_desc .append ((var .dtype , shape ))
67+ desc ["dtype" ] = var .dtype
6768 output_desc [out_name ] = var_desc
6869 desc ["outputs" ] = output_desc
6970
@@ -166,6 +167,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
166167 shard_sizes ,
167168 )
168169 var_desc .append ((var .dtype , shape ))
170+ desc ["dtype" ] = var .dtype
169171
170172 # For special op such as fill_constant_batch_size_like
171173 if op .type == "fill_constant_batch_size_like" :
def build_comp_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
    """Build per-process computation costs from op descriptions.

    Args:
        op_cost_class: cost class (or factory) invoked once per process with
            keyword args ``op_desc``, ``cluster`` and ``rank``.
        ctx: distributed context (unused in this function; kept for
            interface parity with sibling builders).
        processes: iterable of process ranks to build costs for.
        descs: mapping from process rank to that rank's op description.
        cluster: cluster topology forwarded to every cost object.

    Returns:
        dict: process rank -> instantiated cost object.
    """
    return {
        rank: op_cost_class(op_desc=descs[rank], cluster=cluster, rank=rank)
        for rank in processes
    }
382386
383387
@@ -464,8 +468,9 @@ def build_dp_costs(
464468 desc ["inputs" ]["X" ] = [(var .dtype , shape )]
465469 attrs = {"scale" : 1.0 / dp_degree }
466470 desc ["attrs" ] = attrs
471+ desc ["dtype" ] = var .dtype
467472 scale_op_cost = _g_op_cost_factory ["scale" ](
468- op_desc = desc , cluster = cluster
473+ op_desc = desc , cluster = cluster , rank = rank
469474 )
470475 scale_costs [rank ] = scale_op_cost
471476 result .append (scale_costs )
@@ -862,11 +867,12 @@ def _check_comm_op_type(cls):
862867class CompOpCost (OpCost ):
863868 OP_TYPE = "COMP"
864869
865- def __init__ (self , op = None , op_desc = None , cluster = None ):
870+ def __init__ (self , op = None , op_desc = None , cluster = None , rank = None ):
866871 super ().__init__ (op = op , op_desc = op_desc )
867872 self ._check_comp_op_type ()
868- self ._cost = self .calc_cost ()
869873 self .cluster = cluster
874+ self .rank = rank
875+ self ._cost = self .calc_cost ()
870876
871877 @classmethod
872878 def _check_comp_op_type (cls ):
@@ -876,6 +882,17 @@ def _check_comp_op_type(cls):
876882 f"Please Check op type not in { NON_COMP_TYPE } , but got { cls .OP_TYPE } ."
877883 )
878884
885+ def get_rank_gflops (self , rank , dtype ):
886+ device = self .cluster .get_device (rank )
887+ gflops = 7800
888+ if dtype == paddle .float64 :
889+ gflops = device .dp_gflops
890+ elif dtype == paddle .float32 :
891+ gflops = device .sp_gflops
892+ elif dtype == paddle .float16 or dtype == paddle .bfloat16 :
893+ gflops = device .hp_gflops
894+ return gflops
895+
879896 def calc_flops (self ):
880897 if not self .op_desc :
881898 return 0
@@ -891,8 +908,15 @@ def calc_flops(self):
891908 )
892909
893910 def calc_time (self ):
911+ if self .rank is None or self .op_desc is None :
912+ device_gflops = 7800
913+ else :
914+ device_gflops = self .get_rank_gflops (
915+ self .rank , self .op_desc ["dtype" ]
916+ )
894917 flops_count = self .calc_flops ()
895- return flops_count * 2.9e-7
918+ utilization_rate = 0.65
919+ return flops_count / (utilization_rate * device_gflops ) * 1e-3
896920
897921
898922def register_op_cost (cls ):