Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 79 additions & 106 deletions python/paddle/decomposition/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import math
import os
import time
from collections import deque
from typing import TYPE_CHECKING

import paddle
Expand Down Expand Up @@ -174,8 +173,6 @@ def DebugPrint(*args):
class JudgeFusionLoop:
def __init__(self, program, unrecomputable_ops):
self.ops = program.global_block().ops
self.operand_value_set = set()
self.result_value_set = set()
self.unrecomputable_ops = unrecomputable_ops
self.downstream_unrecomputable_ops_map = {op: set() for op in self.ops}
self.upstream_unrecomputable_ops_map = {op: set() for op in self.ops}
Expand Down Expand Up @@ -214,7 +211,6 @@ def _get_producer_ops(op):
source_op = operand.get_defining_op()
if source_op.get_parent_block() == op.get_parent_block():
producers.add(source_op)
self.operand_value_set.add(operand)
return producers

def _get_consumer_ops(op):
Expand All @@ -223,41 +219,36 @@ def _get_consumer_ops(op):
for parent_op in result.all_used_ops_in_same_block():
if parent_op is not None:
consumers.add(parent_op)
self.result_value_set.add(result)
return consumers

def _get_producer_ops_recursively(root):
visited = set()
queue = deque()
queue.append(root)
visited.add(root)
while queue:
cur = queue.popleft()
self.downstream_unrecomputable_ops_map[cur].add(root)
for new_op in _get_producer_ops(cur):
if new_op in visited:
continue
visited.add(new_op)
queue.append(new_op)

def _get_consumer_ops_recursively(root):
visited = set()
queue = deque()
queue.append(root)
visited.add(root)
while queue:
cur = queue.popleft()
self.upstream_unrecomputable_ops_map[cur].add(root)
for new_op in _get_consumer_ops(cur):
if new_op in visited:
continue
visited.add(new_op)
queue.append(new_op)
def _get_upstream_ops_recursively(cur):
upstream_unrecomputable_ops = set()
for new_op in _get_producer_ops(cur):
upstream_unrecomputable_ops |= (
self.upstream_unrecomputable_ops_map[new_op]
)
if cur.name() in self.unrecomputable_ops:
upstream_unrecomputable_ops.add(cur)
return upstream_unrecomputable_ops

def _get_downstream_ops_recursively(cur):
downstream_unrecomputable_ops = set()
for new_op in _get_consumer_ops(cur):
downstream_unrecomputable_ops |= (
self.downstream_unrecomputable_ops_map[new_op]
)
if cur.name() in self.unrecomputable_ops:
downstream_unrecomputable_ops.add(cur)
return downstream_unrecomputable_ops

for op in self.ops:
if op.name() in self.unrecomputable_ops:
_get_producer_ops_recursively(op)
_get_consumer_ops_recursively(op)
self.upstream_unrecomputable_ops_map[
op
] |= _get_upstream_ops_recursively(op)
for op in reversed(self.ops):
self.downstream_unrecomputable_ops_map[
op
] |= _get_downstream_ops_recursively(op)

def _has_unfusible_op_on_any_path(self, op1, op2):
no_unfusible_op_on_path = (
Expand All @@ -278,11 +269,17 @@ def _has_unfusible_op_on_any_path(self, op1, op2):
else False
)

def _get_operand_value_set(self):
return backward_utils.ValueSet(self.operand_value_set)

def _get_result_value_set(self):
return backward_utils.ValueSet(self.result_value_set)
class Op2IdxMap:
    """Lookup table from an op to its position in a program's global block.

    The table is built once in ``__init__`` from
    ``program.global_block().ops``, so each ``get_idx`` call is a single
    dictionary lookup. Ops that are not in that list when the table is
    built (for example ops created afterwards) are not in the table.
    """

    def __init__(self, program):
        """Record the index of every op in ``program.global_block().ops``.

        Args:
            program: Object exposing ``global_block().ops`` as an iterable
                of hashable ops.
        """
        self.op_to_idx_map = {
            op: idx for idx, op in enumerate(program.global_block().ops)
        }

    def get_idx(self, op):
        """Return the index recorded for ``op``.

        Args:
            op: The op to look up.

        Returns:
            The position ``op`` had in the global block's op list when this
            map was constructed.

        Raises:
            RuntimeError: If ``op`` was not recorded at construction time.
        """
        # Look the key up directly rather than testing the value's
        # truthiness: index 0 is a valid position but is falsy, so a
        # truthiness check would report the first op as missing.
        try:
            return self.op_to_idx_map[op]
        except KeyError:
            raise RuntimeError("op not found in program") from None


def auto_recompute(
Expand Down Expand Up @@ -736,12 +733,6 @@ def partition_joint_graph(
mem += cal_value_node_size(mid)
DebugPrint("Saved Memory is: ", mem / 1024 / 1024 / 1024, "GB")

def getIdx(program, op):
for idx, op_iter in enumerate(program.global_block().ops):
if op == op_iter:
return idx
raise RuntimeError("op not found in program")

# 2. Extract the recompute subgraph and replace forward mid hold values with recompute subgraph's outputs
program, fwd_op_end_idx = replace_mid_values_with_forward_subgraph(
program,
Expand All @@ -757,6 +748,7 @@ def getIdx(program, op):
def replace_mid_values_with_forward_subgraph(
program, saved_values, mid_values, fwd_op_end_idx, backward_op_start_idx
):

def _extract_forward_recompute_subgraph_for_backward(
saved_values, mid_values
):
Expand All @@ -767,12 +759,6 @@ def _find_recompute_ops(
needed_saved_values,
chain,
):
def getIdx(program, op):
for idx, op_iter in enumerate(program.global_block().ops):
if op == op_iter:
return idx
raise RuntimeError("op not found in program")

new_chain = list(chain)
new_chain.append(recompute_value)
define_op = recompute_value.get_defining_op()
Expand All @@ -790,13 +776,6 @@ def getIdx(program, op):
"pd_op.full",
"pd_op.full_int_array",
]:

def getIdx(program, op):
for idx, op_iter in enumerate(program.global_block().ops):
if op == op_iter:
return idx
raise RuntimeError("op not found in program")

raise Exception(
f"Every path to recompute value {recompute_value} must have saved value or starting point of the path is one of op in [pd_op.full, pd_op.full_int_array], but find {define_op.name()} op, op ir is {define_op}"
)
Expand All @@ -820,12 +799,6 @@ def getIdx(program, op):
recompute_subgraph_inputs = backward_utils.ValueSet()
recompute_subgraph_outputs_backward_needed = mid_values

def getIdx(program, op):
for idx, op_iter in enumerate(program.global_block().ops):
if op == op_iter:
return idx
raise RuntimeError("op not found in program")

for recompute_value in mid_values:
_find_recompute_ops(
recompute_value,
Expand All @@ -844,6 +817,8 @@ def getIdx(program, op):
}
return recompute_subgraph

op_2_id_map = Op2IdxMap(program)

forward_ops = set(program.global_block().ops[: fwd_op_end_idx + 1])
backward_ops = set(program.global_block().ops[backward_op_start_idx:])
first_backward_op = program.global_block().ops[backward_op_start_idx]
Expand All @@ -859,12 +834,13 @@ def getIdx(program, op):
origin_ops = recompute_forward_subgraph["recompute_ops"]
origin_subgraph_inputs = recompute_forward_subgraph["inputs"]
origin_subgraph_outputs = recompute_forward_subgraph["outputs"]
cloned_ops, value_map = clone_graph(
cloned_ops, value_map, cloned_op_first_grad_user_map = clone_graph(
program,
origin_ops,
origin_subgraph_inputs,
first_backward_op,
backward_ops,
op_2_id_map,
)

for origin_op in origin_ops:
Expand All @@ -880,17 +856,22 @@ def getIdx(program, op):
cloned_subgraph_outputs.add(cloned_value)

# 4. reset recomputed ops location in program
reseted_ops = set()
backward_ops_list = program.global_block().ops[backward_op_start_idx:]
for op in backward_ops_list:
op_inputs = op.operands_source()
for op_input in op_inputs:
if op_input in cloned_subgraph_outputs:
parent_ops = find_parent_ops(op_input)
for cloned_op in cloned_ops:
if cloned_op in parent_ops and cloned_op not in reseted_ops:
cloned_op.move_before(op)
reseted_ops.add(cloned_op)
for op in reversed(cloned_ops):
first_subgraph_grad_user = cloned_op_first_grad_user_map.get(op, None)
for op_outputs in op.results():
for child in op_outputs.all_used_ops_in_same_block():
if cloned_op_first_grad_user_map.get(child, 0):
if first_subgraph_grad_user is None or op_2_id_map.get_idx(
cloned_op_first_grad_user_map[child]
) < op_2_id_map.get_idx(first_subgraph_grad_user):
first_subgraph_grad_user = (
cloned_op_first_grad_user_map[child]
)
assert first_subgraph_grad_user is not None
cloned_op_first_grad_user_map[op] = first_subgraph_grad_user

for cloned_op in cloned_ops:
cloned_op.move_before(cloned_op_first_grad_user_map[cloned_op])
return program, fwd_op_end_idx


Expand Down Expand Up @@ -1115,28 +1096,42 @@ def analyze_mid_hold_values(
return mid_hold_values


def get_first_backward_use_op(fwd_op, backward_ops):
def get_first_backward_use_op(fwd_op, backward_ops, op_2_id_map):
first_backward_use_op = None
for user_op in fwd_op.results()[0].all_used_ops_in_same_block():
if user_op in backward_ops:
return user_op
if user_op in backward_ops and (
first_backward_use_op is None
or op_2_id_map.get_idx(user_op)
< op_2_id_map.get_idx(first_backward_use_op)
):
first_backward_use_op = user_op
return first_backward_use_op


def clone_graph(
program, origin_ops, graph_inputs, clone_insertion_op, backward_ops
program,
origin_ops,
graph_inputs,
clone_insertion_op,
backward_ops,
op_2_id_map,
):
pir.set_insertion_point(clone_insertion_op)
all_ops = program.global_block().ops
value_map = paddle.pir.IrMapping()
origin_ops = set(origin_ops)
cloned_ops = []
cloned_op_first_grad_user_map = {}
for input_value in graph_inputs:
value_map.add(input_value, input_value)
for op in all_ops:
if op in origin_ops:
new_op = op.clone(
value_map, paddle.pir.CloneOptions(False, True, True)
)
first_backward_use_op = get_first_backward_use_op(op, backward_ops)
first_backward_use_op = get_first_backward_use_op(
op, backward_ops, op_2_id_map
)
if (
first_backward_use_op is not None
and first_backward_use_op.has_attr('op_role')
Expand All @@ -1145,29 +1140,7 @@ def clone_graph(
new_op.set_int_attr("op_role", first_backward_use_op.op_role)
new_op.set_int_attr("chunk_id", first_backward_use_op.chunk_id)
cloned_ops.append(new_op)
if first_backward_use_op is not None:
cloned_op_first_grad_user_map[new_op] = first_backward_use_op
pir.set_insertion_point_to_block_end(program.global_block())
return cloned_ops, value_map


def find_parent_ops(value):
visited = backward_utils.ValueSet()

def _find_parent_ops(value):
parent_ops = set()
stack = [value]

while stack:
current = stack.pop()
if current in visited:
continue
visited.add(current)
parent_op = current.get_defining_op()
if parent_op is not None:
parent_ops.add(parent_op)
op_inputs = parent_op.operands_source()
for op_input in op_inputs:
if current.initialized():
stack.append(op_input)
return parent_ops

return _find_parent_ops(value)
return cloned_ops, value_map, cloned_op_first_grad_user_map
Loading