Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 75 additions & 20 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h"

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
Expand Down Expand Up @@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
* remove it in axes.bind()
*/
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
<< replaced_expr << ")";
VLOG(4) << " Input is " << e;
PADDLE_ENFORCE_NE(
e.As<ir::ScheduleBlockRealize>(),
nullptr,
Expand All @@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
for (const auto& i_var : schedule_block_iter_vars) {
PADDLE_ENFORCE_EQ(
i_var.is_var(),
true,
::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
"axes.bind rhs is is not a Var."));
}
// find replace idx
int target_idx = -1;
for (int i = 0; i < schedule_block_iter_vars.size(); ++i) {
VLOG(4) << "RemoveVarInScheduleBlockRealize: compare with "
<< schedule_block_iter_vars[i] << " vs " << target_vars
<< ", and equality is: "
<< (schedule_block_iter_vars[i].as_var()->name ==
target_vars->name);
if (schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
if (schedule_block_iter_vars[i].is_var() &&
schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
target_idx = i;
}
}
Expand Down Expand Up @@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
.GetSingle(copied);
const ir::Expr& target_block =
ExprSetFinderUtils::DirectlyFather(copied).GetSingle(target_for);
VLOG(4) << "RemoveOneTransformer: directly target_block of for is "
<< target_block;
if (target_block.As<ir::ScheduleBlockRealize>() != nullptr) {
VLOG(4) << "RemoveOneTransformer: father block is root realize";
ir::Expr shedule_block =
Expand All @@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
shedule_block.As<ir::ScheduleBlock>()->body = for_body;
}
} else if (target_block.As<ir::Block>() != nullptr) {
VLOG(4) << "RemoveOneTransformer: father block is Block";
std::vector<ir::Expr> new_bodies;
for (const auto& expr : target_block.As<ir::Block>()->stmts) {
if (expr != target_for) {
Expand All @@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
"RemoveOneTransformer: target for father should be a ir::Block or "
"ir::ScheduleBlockRealize."));
}
VLOG(4) << "Remove Var to 0 in ScheduleBlockRealizer: " << copied;
// Remove var to 0 in ScheduleBlockRealizer
InplaceMutateSingleExpr(
&copied,
Expand Down Expand Up @@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {

ir::Expr GetBodyBlock(const ir::Expr& root) {
const auto& iters = GetNonReduceLoopVars(root);
if (iters.empty()) {
return ir::Block::Make(
{ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle(root)});
}
const size_t reduce_size =
std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
return v->is_reduce_axis;
Expand All @@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
->body;
}

// Builds a copy of `root` whose loop nest is rearranged from `in_shape` to
// `out_shape`. `root` itself is not touched: all work happens on an IRCopy.
//
// The axes are grouped by fusion::PartionReshapeAxes(in_shape, out_shape) and
// the groups are processed from the last one to the first one. For each group:
//   * every input axis whose dim equals DimExpr(1) has its loop removed with
//     RemoveOneTransformer,
//   * the remaining input axes of the group are fused when there is more than
//     one of them,
//   * the loop at the group's first input position is split by the group's
//     output dims that are not DimExpr(1), when there is more than one.
// Afterwards a loop is inserted (InsertForsTransformer) at every output axis
// whose dim equals DimExpr(1).
//
// @param root       expression containing a single child ScheduleBlockRealize.
// @param in_shape   symbolic shape the current loops correspond to.
// @param out_shape  symbolic shape the returned loops should correspond to.
// @return the transformed copy.
ir::Expr ReshapeLoop(const ir::Expr& root,
                     const std::vector<symbol::DimExpr>& in_shape,
                     const std::vector<symbol::DimExpr>& out_shape) {
  auto copied = ir::ir_utils::IRCopy(root);

  // NOTE(review): Fuse/Split below are applied through `ir_sch`, yet `copied`
  // is what gets returned. This presumably relies on IRSchedule mutating the
  // expression it was given in place -- confirm.
  ir::ModuleExpr mod_expr({copied});
  ir::IRSchedule ir_sch(
      mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);

  const auto block_realize =
      (ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle(copied);
  const auto block_name = block_realize.As<ir::ScheduleBlockRealize>()
                              ->schedule_block.As<ir::ScheduleBlock>()
                              ->name;
  const auto shape_partition =
      fusion::PartionReshapeAxes(in_shape, out_shape);
  const auto one = symbol::DimExpr(1);

  // Walk the partitions back to front so that removing or fusing loops in a
  // later group does not shift the loop positions of the earlier groups.
  for (int idx = static_cast<int>(shape_partition.size()) - 1; idx > 0;
       --idx) {
    const auto& in_s = shape_partition[idx - 1].first;
    const auto& in_e = shape_partition[idx].first;
    const auto& out_s = shape_partition[idx - 1].second;
    const auto& out_e = shape_partition[idx].second;

    std::vector<int> fuse_indices;
    for (int i = in_e - 1; i >= in_s; --i) {
      if (in_shape[i] != one) {
        fuse_indices.insert(fuse_indices.begin(), i);
      } else {
        VLOG(4) << "Remove index[" << i << "]: " << in_shape[i]
                << " for expr: \n"
                << copied;
        copied = ExprTransformerUtils::RemoveOneTransformer(i)(copied);
        ir_sch.SetExprs({copied});
        // The loop at position i is gone, so every already collected index
        // (all of them are greater than i) moves down by one.
        for (auto& index : fuse_indices) {
          index--;
        }
      }
    }
    if (fuse_indices.size() > 1) {
      VLOG(4) << "fuse_indices: " << cinn::utils::Join(fuse_indices, ",");
      ir_sch.Fuse(block_name, fuse_indices);
    }

    std::vector<ir::Expr> split_shapes;
    for (int i = out_s; i < out_e; ++i) {
      if (out_shape[i] != one) {
        split_shapes.push_back(
            cinn::common::DimExprConverter().ConvertToIrExpr(out_shape[i]));
      }
    }
    if (split_shapes.size() > 1) {
      // Only the side effect on the schedule is needed; the returned loops
      // are intentionally not inspected.
      ir_sch.Split(ir_sch.GetLoops(block_name)[in_s], split_shapes);
    }
  }

  // Re-introduce one loop per output axis whose dim is 1.
  std::vector<int> insert_axis;
  std::vector<ir::Var> ones_var;
  const int out_rank = static_cast<int>(out_shape.size());
  for (int i = 0; i < out_rank; ++i) {
    if (out_shape[i] == one) {
      insert_axis.push_back(i);
      ones_var.push_back(ir::Var(1, "one_" + std::to_string(ones_var.size())));
    }
  }
  copied = ExprTransformerUtils::InsertForsTransformer(insert_axis,
                                                       ones_var)(copied);

  return copied;
}

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root);

ir::Expr GetBodyBlock(const ir::Expr& root);

// Returns a transformed copy of `root` (the input expression is copied, not
// modified) whose loops are rearranged from `in_shape` to `out_shape`:
// loops of input dims equal to 1 are removed, the remaining loops are
// fused/split per reshape axis partition, and a loop is inserted for every
// output dim equal to 1.
ir::Expr ReshapeLoop(const ir::Expr& root,
const std::vector<symbol::DimExpr>& in_shape,
const std::vector<symbol::DimExpr>& out_shape);

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/ir/group_schedule/config/group_tile_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
for (int i = 0; i < iter_vars.size(); i++) {
ir::Var loop_var = block->iter_values[i].as_var_ref();
if (reduce_loop_vars.count(loop_var->name) > 0) {
if (block->iter_values[i].is_var() &&
reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
reduce_iter_vars.insert(iter_vars[i]->name);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ir::Expr ApplyItersTransform::operator()(const TransposeItersTransform& trans) {

ir::Expr ApplyItersTransform::operator()(const RemoveOnesTransform& trans) {
VLOG(4) << "[ItersTransform] Before RemoveOnesTransform("
<< utils::Join(trans.ones_, ",") << "'): " << expr_;
<< utils::Join(trans.ones_, ",") << "): " << expr_;
auto result = RemoveOnesTransformer(trans.ones_)(expr_);
VLOG(4) << "[ItersTransform] After RemoveOnesTransform: " << result;
return result;
Expand Down
18 changes: 18 additions & 0 deletions paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,20 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
interpreter->scope[instr->target_] = new_pattern;
}

void RunReshapeAlignInstr(const std::shared_ptr<ReshapeAlignInstr>& instr,
FusionInterpreter* interpreter) {
const auto expr = std::visit(
FusibleOp2Expr(), interpreter->scope[instr->input_]->fusion_ops[0])[0];
VLOG(4) << "Before RunReshapeAlignInstr: \n" << expr;
auto result = cinn::hlir::framework::pir::trivial_fusion_detail::ReshapeLoop(
expr, instr->in_shape_, instr->out_shape_);

auto new_pattern = std::make_shared<ScopeElement>();
new_pattern->fusion_ops.emplace_back(TrivialOp(result));
interpreter->scope[instr->result_] = new_pattern;
VLOG(4) << "After ReshapeAlignInstr: \n" << result;
}

void RunPaddingInstr(const std::shared_ptr<PaddingInstr>& instr,
FusionInterpreter* interpreter) {
ScopeElementPtr new_pattern = std::make_shared<ScopeElement>();
Expand Down Expand Up @@ -229,6 +243,10 @@ std::vector<ir::Expr> FusionInterpreter::Run() {
RunItersTransformInstr(
dynamic_cast_instr_with_err<ItersTransformInstr>(instr), this);
break;
case T_ReshapeAlign:
RunReshapeAlignInstr(
dynamic_cast_instr_with_err<ReshapeAlignInstr>(instr), this);
break;
default:
PADDLE_THROW(
::common::errors::Unavailable("Unsupported Fusion Instrution"));
Expand Down
27 changes: 27 additions & 0 deletions paddle/cinn/operator_fusion/fusion_tracker/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum InstructionType {
T_Return,
T_InitPattern,
T_TrivialInline,
T_ReshapeAlign,
T_TmpTransform,
T_TrivialLoopAlign,
T_ItersTransform,
Expand Down Expand Up @@ -143,6 +144,32 @@ struct TrivialInlineInstr : public FusionInstruction {
}
};

// Fusion instruction of type T_ReshapeAlign. It records the name of an input
// pattern, the name under which the result is stored, and the symbolic shapes
// before (`in_shape_`) and after (`out_shape_`) the reshape.
struct ReshapeAlignInstr : public FusionInstruction {
  ReshapeAlignInstr(const std::string& input,
                    const std::vector<symbol::DimExpr>& in_shape,
                    const std::vector<symbol::DimExpr>& out_shape,
                    const std::string& result)
      : input_(input),
        in_shape_(in_shape),
        out_shape_(out_shape),
        result_(result) {}

  // Identifies this instruction kind.
  virtual InstructionType type() const { return T_ReshapeAlign; }

  // Produces a copy of this instruction held by a shared pointer.
  virtual FusionInstrPtr Clone() {
    return std::make_shared<ReshapeAlignInstr>(*this);
  }

  std::string input_;
  std::vector<symbol::DimExpr> in_shape_;
  std::vector<symbol::DimExpr> out_shape_;
  std::string result_;

  // Renders "ReshapeAlignInstr || <input>(<in dims>) => <result>(<out dims>)"
  // with the dims joined by commas.
  virtual std::string DebugStr() const {
    std::string text = "ReshapeAlignInstr || ";
    text += input_;
    text += "(";
    text += cinn::utils::Join(in_shape_, ",");
    text += ") => ";
    text += result_;
    text += "(";
    text += cinn::utils::Join(out_shape_, ",");
    text += ")";
    return text;
  }
};

struct TmpTransformInstr : public FusionInstruction {
TmpTransformInstr(const std::string& upstream,
const std::string& downstream,
Expand Down
108 changes: 85 additions & 23 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,39 @@ struct AlwaysTrue {
}
};

// Matches nodes that have at least one downstream node.
struct NonSinkNodeMatcher {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const bool has_no_consumer = node->downstream().empty();
    return !has_no_consumer;
  }
};

// Matches nodes for which IsAnyFirstInSecond reports that a result of the
// node's sink op is among the graph outputs.
struct IsOutputNodeMatcher {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const bool produces_graph_output =
        IsAnyFirstInSecond(node->sink_op()->results(), graph.outputs());
    return produces_graph_output;
  }
};

// Matches nodes whose number of downstream nodes is less than N.
template <int N>
struct DownstreamSmallerThan {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto downstream_count = node->downstream().size();
    return downstream_count < N;
  }
};

// Matches nodes whose number of downstream nodes is greater than N.
template <int N>
struct DownstreamGreaterThan {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto downstream_count = node->downstream().size();
    return downstream_count > N;
  }
};

// Matches nodes that have exactly one downstream node.
struct OnlyOneDownstreamMatcher {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto downstream_count = node->downstream().size();
    return downstream_count == 1;
  }
};

template <typename StmtPattern>
struct StmtPatternGraphMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
Expand Down Expand Up @@ -139,7 +172,15 @@ struct RecomputeNodeMatcher {
if (node->fusion_iters().output_values.size() > 1) {
return false;
}

bool has_combine_fusion =
std::any_of(node->fusion_tracker()->instructions_.begin(),
node->fusion_tracker()->instructions_.end(),
[](const FusionInstrPtr& instr) {
return instr->type() == T_Combine;
});
if (has_combine_fusion) {
return false;
}
for (const auto& op : GetOpsInPattern(node->stmt_pattern())) {
const auto& op_kind = GetOpPatternKind(op);
if (op_kind >= hlir::framework::kReduction) {
Expand Down Expand Up @@ -183,9 +224,50 @@ struct TransposeOpMatcher {
}
};

struct NonSinkNodeMatcher {
struct ReshapeOpMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
return !node->downstream().empty();
return (node->sink_op()->name() == "cinn_op.reshape");
}
};

struct ReshapeConnectionMatcher {
bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
bool upstream_match =
node->downstream().size() == 1 &&
node->downstream()[0]->sink_op()->name() == "cinn_op.reshape" &&
node->downstream()[0]->downstream().size() == 1;
bool downstream_match = node->sink_op()->name() == "cinn_op.reshape" &&
node->downstream().size() == 1;
return upstream_match || downstream_match;
}
};

// Matches either end of an edge "trivial producer -> reshape" where:
//   * the producer is a TrivialPattern node with exactly one downstream and a
//     non-empty upstream list in which at least one node has more than one
//     downstream, and
//   * the reshape ("cinn_op.reshape") has exactly one downstream, that
//     downstream has no downstream of its own, and both share equal
//     fusion_iters().loop_iters.
struct LeafReshapeConnectionMatcher {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto is_producer_side = [&graph](const PatternNodePtr& candidate) {
      if (!StmtPatternGraphMatcher<TrivialPattern>()(graph, candidate)) {
        return false;
      }
      if (candidate->downstream().size() != 1) {
        return false;
      }
      if (candidate->upstream().empty()) {
        return false;
      }
      const auto& producers = candidate->upstream();
      return std::any_of(producers.begin(),
                         producers.end(),
                         [&graph](const PatternNodePtr& producer) {
                           return DownstreamGreaterThan<1>()(graph, producer);
                         });
    };

    const auto is_reshape_side = [](const PatternNodePtr& candidate) {
      if (!(candidate->sink_op()->name() == "cinn_op.reshape")) {
        return false;
      }
      if (candidate->downstream().size() != 1) {
        return false;
      }
      const auto consumer = candidate->downstream()[0];
      if (!consumer->downstream().empty()) {
        return false;
      }
      const bool same_loop_iters = candidate->fusion_iters().loop_iters ==
                                   consumer->fusion_iters().loop_iters;
      return same_loop_iters;
    };

    // `node` is the producer and its single downstream is the reshape.
    const bool matched_as_producer = is_producer_side(node) &&
                                     node->downstream().size() == 1 &&
                                     is_reshape_side(node->downstream()[0]);
    // `node` is the reshape and its single upstream is the producer.
    const bool matched_as_reshape = is_reshape_side(node) &&
                                    node->upstream().size() == 1 &&
                                    is_producer_side(node->upstream()[0]);
    return matched_as_producer || matched_as_reshape;
  }
};

Expand All @@ -206,26 +288,6 @@ struct NotAllElementWiseDownstreamMatcher {
}
};

// Matches nodes for which IsAnyFirstInSecond reports that a result of the
// node's sink op is among the graph outputs.
struct IsOutputNodeMatcher {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const bool produces_graph_output =
        IsAnyFirstInSecond(node->sink_op()->results(), graph.outputs());
    return produces_graph_output;
  }
};

// Matches nodes whose number of downstream nodes is less than N.
template <int N>
struct DownstreamSmallerThan {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto downstream_count = node->downstream().size();
    return downstream_count < N;
  }
};

// Matches nodes whose number of downstream nodes is greater than N.
template <int N>
struct DownstreamGreaterThan {
  bool operator()(const PatternGraph& graph, const PatternNodePtr& node) {
    const auto downstream_count = node->downstream().size();
    return downstream_count > N;
  }
};
// Primary template of the variadic `And` combinator; it is intentionally
// empty here.
// NOTE(review): presumably specialized further down to conjoin matchers --
// the specializations are not visible in this view, confirm.
template <typename... Args>
struct And {};

Expand Down
Loading