
Commit bab1629

Difers authored and wentaoyu committed
update rule_based_tuner for hetero training. (PaddlePaddle#63280)
* add mesh_group; adjust cost model; fix code style
* fix ci: remove print; fix ci: ut error
* fix topo multi machine run; remove debug log
* update rule_based_tuner with new cluster
* fix some typos
* add tune o3 for hetero in pp
* run tune o3 when hetero case
* fix some bugs

---------

Co-authored-by: hitywt <[email protected]>
1 parent 03688cf commit bab1629

File tree

12 files changed: +1110 -200 lines changed


python/paddle/distributed/auto_parallel/static/cluster.py

Lines changed: 517 additions & 13 deletions
Large diffs are not rendered by default.

python/paddle/distributed/auto_parallel/static/cost/base_cost.py

Lines changed: 29 additions & 5 deletions
@@ -64,6 +64,7 @@ def build_comp_desc_from_op(op):
             var = get_var_with_recursion(var_name, op.block, op.block.program)
             shape = var.shape
             var_desc.append((var.dtype, shape))
+            desc["dtype"] = var.dtype
         output_desc[out_name] = var_desc
     desc["outputs"] = output_desc

@@ -166,6 +167,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                     shard_sizes,
                 )
                 var_desc.append((var.dtype, shape))
+                desc["dtype"] = var.dtype
 
             # For special op such as fill_constant_batch_size_like
             if op.type == "fill_constant_batch_size_like":
@@ -377,7 +379,9 @@ def build_comp_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
     """Build comp costs by descriptions."""
     costs = {}
     for process in processes:
-        costs[process] = op_cost_class(op_desc=descs[process], cluster=cluster)
+        costs[process] = op_cost_class(
+            op_desc=descs[process], cluster=cluster, rank=process
+        )
     return costs
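In effect, each process's compute cost is now priced against that process's own device rather than one hard-coded throughput, and the op description carries the dtype needed to pick the matching peak GFLOPS; this is what makes the estimate meaningful on the heterogeneous clusters this PR targets (see the arithmetic sketch after the calc_time hunk below).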

@@ -464,8 +468,9 @@ def build_dp_costs(
             desc["inputs"]["X"] = [(var.dtype, shape)]
             attrs = {"scale": 1.0 / dp_degree}
             desc["attrs"] = attrs
+            desc["dtype"] = var.dtype
             scale_op_cost = _g_op_cost_factory["scale"](
-                op_desc=desc, cluster=cluster
+                op_desc=desc, cluster=cluster, rank=rank
             )
             scale_costs[rank] = scale_op_cost
         result.append(scale_costs)
@@ -862,11 +867,12 @@ def _check_comm_op_type(cls):
 class CompOpCost(OpCost):
     OP_TYPE = "COMP"
 
-    def __init__(self, op=None, op_desc=None, cluster=None):
+    def __init__(self, op=None, op_desc=None, cluster=None, rank=None):
         super().__init__(op=op, op_desc=op_desc)
         self._check_comp_op_type()
-        self._cost = self.calc_cost()
         self.cluster = cluster
+        self.rank = rank
+        self._cost = self.calc_cost()
 
     @classmethod
     def _check_comp_op_type(cls):
@@ -876,6 +882,17 @@ def _check_comp_op_type(cls):
                 f"Please Check op type not in {NON_COMP_TYPE}, but got {cls.OP_TYPE}."
             )
 
+    def get_rank_gflops(self, rank, dtype):
+        device = self.cluster.get_device(rank)
+        gflops = 7800
+        if dtype == paddle.float64:
+            gflops = device.dp_gflops
+        elif dtype == paddle.float32:
+            gflops = device.sp_gflops
+        elif dtype == paddle.float16 or dtype == paddle.bfloat16:
+            gflops = device.hp_gflops
+        return gflops
+
     def calc_flops(self):
         if not self.op_desc:
             return 0
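For example (illustrative device numbers, not values from the patch): an op whose description records float16 on a rank whose device reports hp_gflops = 312000 is priced against 312 TFLOPS of peak half-precision throughput, the same op in float32 uses that device's sp_gflops, and any dtype outside the three branches falls back to the hard-coded 7800 GFLOPS default.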
@@ -891,8 +908,15 @@ def calc_flops(self):
             )
 
     def calc_time(self):
+        if self.rank is None or self.op_desc is None:
+            device_gflops = 7800
+        else:
+            device_gflops = self.get_rank_gflops(
+                self.rank, self.op_desc["dtype"]
+            )
         flops_count = self.calc_flops()
-        return flops_count * 2.9e-7
+        utilization_rate = 0.65
+        return flops_count / (utilization_rate * device_gflops) * 1e-3
 
 
 def register_op_cost(cls):
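Taken together, calc_time replaces the old flat estimate (flops_count * 2.9e-7) with a dtype- and rank-aware one. Below is a minimal self-contained sketch of the new arithmetic; the 0.65 utilization rate and the 7800 GFLOPS fallback come from the diff, while the function name and per-device peak numbers are assumptions for illustration only:

# Sketch of the timing model above (not the Paddle API):
#   time ~= flops / (utilization_rate * peak_gflops) * 1e-3
# which comes out in microseconds when `flops` is a raw FLOP count and
# `peak_gflops` is the device's peak throughput in GFLOPS (1e9 FLOP/s).
def estimated_op_time(flops, peak_gflops=7800, utilization_rate=0.65):
    return flops / (utilization_rate * peak_gflops) * 1e-3

# The same 2 GFLOP op priced on two heterogeneous ranks (assumed peaks):
print(estimated_op_time(2e9, peak_gflops=312000))  # faster fp16 device: ~9.9
print(estimated_op_time(2e9, peak_gflops=125000))  # slower fp16 device: ~24.6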
