Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ set(cinn_transforms_deps
op_dialect_vjp
cinn_runtime_dialect
op_fusion
pir_compiler)
pir_compiler
json)

include_directories(cinn_transforms PRIVATE
${PADDLE_SOURCE_DIR}/third_party/nlohmann_json/include/)

cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
${cinn_transforms_deps})
Expand Down
13 changes: 13 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_assign_out_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/specify_input_dynamic_dim_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"

COMMON_DECLARE_bool(cinn_specify_input_dynamic_dim);
COMMON_DECLARE_string(cinn_input_dynamic_dim_spec_file);
COMMON_DECLARE_bool(print_ir);
COMMON_DECLARE_bool(pir_debug);
COMMON_DECLARE_bool(disable_dyshape_in_train);
Expand Down Expand Up @@ -97,6 +100,16 @@ void ApplyShapeOptimizationPass(
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
bool has_dynamic_shape = HasDynamicShape(*program);
if (has_dynamic_shape) {
if (FLAGS_cinn_specify_input_dynamic_dim) {
PADDLE_ENFORCE_NE(
FLAGS_cinn_input_dynamic_dim_spec_file,
"",
::common::errors::InvalidArgument(
"'FLAGS_cinn_input_dynamic_dim_spec_file' should not be empty "
"when using FLAGS_cinn_specify_input_dynamic_dim."));
SpecifyInputDynamicDimFromFile(program,
FLAGS_cinn_input_dynamic_dim_spec_file);
}
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
}
pass_manager->Run(program);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/specify_input_dynamic_dim_util.h"

#include <sys/stat.h>
#include <fstream>
#include "nlohmann/json.hpp"

using Json = nlohmann::json;

namespace cinn {
namespace dialect {
namespace ir {

namespace {

std::vector<pir::InputDynamicDimSpec> DeserializeInputDynamicDimSpecFromJson(
const Json& json) {
std::vector<pir::InputDynamicDimSpec> res;
for (const auto& element : json.items()) {
pir::InputDynamicDimSpec dim_spec;
dim_spec.dim_name = [&]() -> std::string { return element.key(); }();
dim_spec.input_bind = [&]() {
const auto& value = element.value();
std::vector<std::pair<std::string, int>> res;
PADDLE_ENFORCE_EQ(value.contains("input_bind"),
true,
::common::errors::InvalidArgument(
"input dynamic dim spec must contain input_bind"));
for (const auto& bind_item : value["input_bind"]) {
const auto& input_name = bind_item[0].get<std::string>();
const auto& dim_index = bind_item[1].get<int>();
res.emplace_back(std::make_pair(input_name, dim_index));
}
return res;
}();
dim_spec.range = [&]() {
const auto& value = element.value();
symbol::ConstraintsManager::Range range;
if (value.contains("min")) {
range.min = value["min"].get<int>();
}
if (value.contains("max")) {
range.max = value["max"].get<int>();
}
return range;
}();
res.emplace_back(std::move(dim_spec));
}
return res;
}

bool PathExists(const std::string& path) {
struct stat statbuf;
if (stat(path.c_str(), &statbuf) != -1) {
return true;
}
return false;
}

std::vector<pir::InputDynamicDimSpec>
DeserializeInputDynamicDimSpecFromJsonFile(std::string file_path) {
PADDLE_ENFORCE_EQ(
PathExists(file_path),
true,
::common::errors::InvalidArgument(
"File path for input dynamic dim spec not exists: %s.", file_path));
std::ifstream ifs(file_path);
PADDLE_ENFORCE_EQ(
!ifs,
false,
::common::errors::InvalidArgument(
"File path for input dynamic dim spec fail to open for reading: %s.",
file_path));
Json json;
ifs >> json;
return DeserializeInputDynamicDimSpecFromJson(json);
}

} // namespace

void SpecifyInputDynamicDim(
pir::Program* program,
const std::vector<pir::InputDynamicDimSpec>& input_dynamic_dim_spec) {
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(program);
shape_analysis.SetInputDynamicDimSpec(input_dynamic_dim_spec);
}

void SpecifyInputDynamicDimFromFile(pir::Program* program,
std::string filepath) {
SpecifyInputDynamicDim(program,
DeserializeInputDynamicDimSpecFromJsonFile(filepath));
}

} // namespace ir
} // namespace dialect
} // namespace cinn
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

namespace cinn {
namespace dialect {
namespace ir {
void SpecifyInputDynamicDim(
pir::Program* program,
const std::vector<pir::InputDynamicDimSpec>& input_dynamic_dim_spec);
void SpecifyInputDynamicDimFromFile(pir::Program* program,
std::string filepath);
} // namespace ir
} // namespace dialect
} // namespace cinn
26 changes: 26 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,32 @@ PHI_DEFINE_EXPORTED_string(cinn_subgraph_graphviz_dir,
"Specify the directory path of dot file of "
"graph, which is used for debug.");

/*
* CINN related FLAG
* Name: FLAGS_cinn_specify_input_dynamic_dim
* Since Version: develop
* Value Range: bool, default=false
* Example: FLAGS_cinn_specify_input_dynamic_dim=true will use file set by
* FLAGS_cinn_input_dynamic_dim_spec_file to specify input dynamic dimention.
*/
PHI_DEFINE_EXPORTED_bool(cinn_specify_input_dynamic_dim,
false,
"Whether to specify input dynamic dimention.");
Comment on lines +1107 to +1117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个flag是不可以不用添加?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之后考虑给path指定默认值,所以这里需要一个直接表明是否开启的开关


/*
* CINN related FLAG
* Name: FLAGS_cinn_input_dynamic_dim_spec_file
* Since Version: develop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Since Version: develop
* Since Version: 3.0 beta2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收到,下一个PR顺带修改

* Value Range: string, default=""
* Example: FLAGS_cinn_input_dynamic_dim_spec_file="./config.json",
* FLAGS_cinn_specify_input_dynamic_dim=true would use input dynamic dimention
* predefined in ./config.json to specify input dynamic dimention.
*/
PHI_DEFINE_EXPORTED_string(
cinn_input_dynamic_dim_spec_file,
"",
"File path of predefined input dynamic dimention specification.");

#endif

/*
Expand Down
23 changes: 12 additions & 11 deletions paddle/pir/include/dialect/shape/utils/shape_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ class IR_API InferSymbolicShapeCacheKey {
void SetInputShapeOrDatas(
const std::vector<symbol::ShapeOrDataDimExprs>& input_shape_or_datas);
};
struct ConstraintsForInputDimExpr {
symbol::DimExpr dim_expr;
// bind_info = [(input_name, dim_index)]
std::vector<std::pair<std::string, int>> bind_info;

struct InputDynamicDimSpec {
std::string dim_name;
// input_bind = [(input_name, dim_index)]
std::vector<std::pair<std::string, int>> input_bind;
symbol::ConstraintsManager::Range range;
};
} // namespace pir
Expand All @@ -81,8 +82,7 @@ class IR_API InferSymbolicShapeContext {
InferSymbolicShapeContext() = default;
InferSymbolicShapeContext(const InferSymbolicShapeContext&) = delete;
InferSymbolicShapeContext(InferSymbolicShapeContext&&) = delete;
void Init(
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints);
void Init(const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec);

// Note: Only initialize the symbol info, the value info is not update.
void RegisterSymbolConstraintFromContext(
Expand Down Expand Up @@ -164,6 +164,9 @@ class IR_API InferSymbolicShapeContext {

std::unordered_map<std::string, std::vector<DimIndexAndExpr>>
predefined_dimexpr_map_for_inputs_;

std::unordered_map<std::string, symbol::DimExpr>
input_dynamic_dim_name_spec_to_dimexpr_map_;
};

class IR_API ShapeConstraintIRAnalysis final
Expand Down Expand Up @@ -226,10 +229,8 @@ class IR_API ShapeConstraintIRAnalysis final
return context_.constraints_manager();
}

void SetInputShapeConstraints(
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints) {
input_shape_constraints_ = input_shape_constraints;
}
void SetInputDynamicDimSpec(
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec);

private:
InferSymbolicShapeContext* MutInferSymbolicShapeContext() {
Expand All @@ -244,7 +245,7 @@ class IR_API ShapeConstraintIRAnalysis final

private:
InferSymbolicShapeContext context_;
std::vector<ConstraintsForInputDimExpr> input_shape_constraints_;
std::vector<InputDynamicDimSpec> input_dynamic_dim_spec_;
};

class IR_API ShapeAnalysisManager {
Expand Down
33 changes: 24 additions & 9 deletions paddle/pir/src/dialect/shape/utils/shape_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,42 @@ static std::string GetValueId(Value val) {
}

void InferSymbolicShapeContext::Init(
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints) {
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec) {
value_id_to_shape_or_data_.clear();
next_sym_idx_ = sym_idx_begin_;
constraints_manager_.SetEqualCallbackFunc(
[&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) {
return SubstituteDimExpr(lhs, rhs);
});

const auto& InitBindInfoForInputDim =
[&](const std::vector<std::pair<std::string, int>>& bind_info,
const auto& CreateDimExprForInputDynamicDim = [&]() {
for (const auto& item : input_dynamic_dim_spec) {
input_dynamic_dim_name_spec_to_dimexpr_map_[item.dim_name] =
symbol::DimExpr{GetNextSymName()};
}
};

const auto& SetDynamicShapeInputBind =
[&](const std::vector<std::pair<std::string, int>>& input_bind,
const symbol::DimExpr& dim_expr) {
for (const auto& item : bind_info) {
for (const auto& item : input_bind) {
predefined_dimexpr_map_for_inputs_[item.first].emplace_back(
DimIndexAndExpr(item.second, dim_expr));
}
};

const auto& InitRangeForInputDim =
const auto& SetDynamicShapeInputRange =
[&](const symbol::ConstraintsManager::Range& range,
const symbol::DimExpr& dim_expr) {
constraints_manager_.AddInputRangeCstr(dim_expr, range);
};

for (const auto& item : input_shape_constraints) {
InitBindInfoForInputDim(item.bind_info, item.dim_expr);
InitRangeForInputDim(item.range, item.dim_expr);
CreateDimExprForInputDynamicDim();
for (const auto& item : input_dynamic_dim_spec) {
const auto& dim_expr =
input_dynamic_dim_name_spec_to_dimexpr_map_.at(item.dim_name);
SetDynamicShapeInputBind(item.input_bind, dim_expr);
SetDynamicShapeInputRange(item.range, dim_expr);
}
}

Expand Down Expand Up @@ -409,7 +419,7 @@ InferSymbolicShapeContext::GetPredefinedDimExprForInputName(
}

void ShapeConstraintIRAnalysis::InitInferContext() {
context_.Init(input_shape_constraints_);
context_.Init(input_dynamic_dim_spec_);
}

void ShapeConstraintIRAnalysis::RegisterSymbolConstraintFromShapeAnalysis(
Expand Down Expand Up @@ -745,6 +755,11 @@ pir::PrintHooks ShapeConstraintIRAnalysis::PrintHook() {
return print_hook;
}

void ShapeConstraintIRAnalysis::SetInputDynamicDimSpec(
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec) {
input_dynamic_dim_spec_ = input_dynamic_dim_spec;
}

ShapeAnalysisManager& ShapeAnalysisManager::Instance() {
static ShapeAnalysisManager instance;
return instance;
Expand Down