Skip to content

Commit 22f3066

Browse files
[CINN]Apply input dynamic dim specification (#67628)
* [CINN]Add config for cinn input shape constraint setting * fix
1 parent e0f3072 commit 22f3066

File tree

7 files changed

+220
-21
lines changed

7 files changed

+220
-21
lines changed

paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ set(cinn_transforms_deps
88
op_dialect_vjp
99
cinn_runtime_dialect
1010
op_fusion
11-
pir_compiler)
11+
pir_compiler
12+
json)
13+
14+
include_directories(cinn_transforms PRIVATE
15+
${PADDLE_SOURCE_DIR}/third_party/nlohmann_json/include/)
1216

1317
cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
1418
${cinn_transforms_deps})

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@
5252
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_assign_out_pass.h"
5353
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
5454
#include "paddle/cinn/hlir/dialect/operator/transforms/shape_ops_fallback_to_phi_pass.h"
55+
#include "paddle/cinn/hlir/dialect/operator/transforms/specify_input_dynamic_dim_util.h"
5556
#include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h"
5657
#include "paddle/fluid/pir/transforms/build_cinn_pass.h"
5758
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"
5859

60+
COMMON_DECLARE_bool(cinn_specify_input_dynamic_dim);
61+
COMMON_DECLARE_string(cinn_input_dynamic_dim_spec_file);
5962
COMMON_DECLARE_bool(print_ir);
6063
COMMON_DECLARE_bool(pir_debug);
6164
COMMON_DECLARE_bool(disable_dyshape_in_train);
@@ -97,6 +100,16 @@ void ApplyShapeOptimizationPass(
97100
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
98101
bool has_dynamic_shape = HasDynamicShape(*program);
99102
if (has_dynamic_shape) {
103+
if (FLAGS_cinn_specify_input_dynamic_dim) {
104+
PADDLE_ENFORCE_NE(
105+
FLAGS_cinn_input_dynamic_dim_spec_file,
106+
"",
107+
::common::errors::InvalidArgument(
108+
"'FLAGS_cinn_input_dynamic_dim_spec_file' should not be empty "
109+
"when using FLAGS_cinn_specify_input_dynamic_dim."));
110+
SpecifyInputDynamicDimFromFile(program,
111+
FLAGS_cinn_input_dynamic_dim_spec_file);
112+
}
100113
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
101114
}
102115
pass_manager->Run(program);
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/cinn/hlir/dialect/operator/transforms/specify_input_dynamic_dim_util.h"
16+
17+
#include <sys/stat.h>
18+
#include <fstream>
19+
#include "nlohmann/json.hpp"
20+
21+
using Json = nlohmann::json;
22+
23+
namespace cinn {
24+
namespace dialect {
25+
namespace ir {
26+
27+
namespace {
28+
29+
std::vector<pir::InputDynamicDimSpec> DeserializeInputDynamicDimSpecFromJson(
30+
const Json& json) {
31+
std::vector<pir::InputDynamicDimSpec> res;
32+
for (const auto& element : json.items()) {
33+
pir::InputDynamicDimSpec dim_spec;
34+
dim_spec.dim_name = [&]() -> std::string { return element.key(); }();
35+
dim_spec.input_bind = [&]() {
36+
const auto& value = element.value();
37+
std::vector<std::pair<std::string, int>> res;
38+
PADDLE_ENFORCE_EQ(value.contains("input_bind"),
39+
true,
40+
::common::errors::InvalidArgument(
41+
"input dynamic dim spec must contain input_bind"));
42+
for (const auto& bind_item : value["input_bind"]) {
43+
const auto& input_name = bind_item[0].get<std::string>();
44+
const auto& dim_index = bind_item[1].get<int>();
45+
res.emplace_back(std::make_pair(input_name, dim_index));
46+
}
47+
return res;
48+
}();
49+
dim_spec.range = [&]() {
50+
const auto& value = element.value();
51+
symbol::ConstraintsManager::Range range;
52+
if (value.contains("min")) {
53+
range.min = value["min"].get<int>();
54+
}
55+
if (value.contains("max")) {
56+
range.max = value["max"].get<int>();
57+
}
58+
return range;
59+
}();
60+
res.emplace_back(std::move(dim_spec));
61+
}
62+
return res;
63+
}
64+
65+
bool PathExists(const std::string& path) {
66+
struct stat statbuf;
67+
if (stat(path.c_str(), &statbuf) != -1) {
68+
return true;
69+
}
70+
return false;
71+
}
72+
73+
std::vector<pir::InputDynamicDimSpec>
74+
DeserializeInputDynamicDimSpecFromJsonFile(std::string file_path) {
75+
PADDLE_ENFORCE_EQ(
76+
PathExists(file_path),
77+
true,
78+
::common::errors::InvalidArgument(
79+
"File path for input dynamic dim spec not exists: %s.", file_path));
80+
std::ifstream ifs(file_path);
81+
PADDLE_ENFORCE_EQ(
82+
!ifs,
83+
false,
84+
::common::errors::InvalidArgument(
85+
"File path for input dynamic dim spec fail to open for reading: %s.",
86+
file_path));
87+
Json json;
88+
ifs >> json;
89+
return DeserializeInputDynamicDimSpecFromJson(json);
90+
}
91+
92+
} // namespace
93+
94+
void SpecifyInputDynamicDim(
95+
pir::Program* program,
96+
const std::vector<pir::InputDynamicDimSpec>& input_dynamic_dim_spec) {
97+
pir::ShapeConstraintIRAnalysis& shape_analysis =
98+
pir::ShapeAnalysisManager::Instance().Get(program);
99+
shape_analysis.SetInputDynamicDimSpec(input_dynamic_dim_spec);
100+
}
101+
102+
void SpecifyInputDynamicDimFromFile(pir::Program* program,
103+
std::string filepath) {
104+
SpecifyInputDynamicDim(program,
105+
DeserializeInputDynamicDimSpecFromJsonFile(filepath));
106+
}
107+
108+
} // namespace ir
109+
} // namespace dialect
110+
} // namespace cinn
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/pir/include/core/program.h"
18+
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
19+
20+
namespace cinn {
21+
namespace dialect {
22+
namespace ir {
23+
void SpecifyInputDynamicDim(
24+
pir::Program* program,
25+
const std::vector<pir::InputDynamicDimSpec>& input_dynamic_dim_spec);
26+
void SpecifyInputDynamicDimFromFile(pir::Program* program,
27+
std::string filepath);
28+
} // namespace ir
29+
} // namespace dialect
30+
} // namespace cinn

paddle/common/flags.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,32 @@ PHI_DEFINE_EXPORTED_string(cinn_subgraph_graphviz_dir,
11041104
"Specify the directory path of dot file of "
11051105
"graph, which is used for debug.");
11061106

1107+
/*
1108+
* CINN related FLAG
1109+
* Name: FLAGS_cinn_specify_input_dynamic_dim
1110+
* Since Version: develop
1111+
* Value Range: bool, default=false
1112+
* Example: FLAGS_cinn_specify_input_dynamic_dim=true will use file set by
1113+
* FLAGS_cinn_input_dynamic_dim_spec_file to specify input dynamic dimention.
1114+
*/
1115+
PHI_DEFINE_EXPORTED_bool(cinn_specify_input_dynamic_dim,
1116+
false,
1117+
"Whether to specify input dynamic dimention.");
1118+
1119+
/*
1120+
* CINN related FLAG
1121+
* Name: FLAGS_cinn_input_dynamic_dim_spec_file
1122+
* Since Version: develop
1123+
* Value Range: string, default=""
1124+
* Example: FLAGS_cinn_input_dynamic_dim_spec_file="./config.json",
1125+
* FLAGS_cinn_specify_input_dynamic_dim=true would use input dynamic dimention
1126+
* predefined in ./config.json to specify input dynamic dimention.
1127+
*/
1128+
PHI_DEFINE_EXPORTED_string(
1129+
cinn_input_dynamic_dim_spec_file,
1130+
"",
1131+
"File path of predefined input dynamic dimention specification.");
1132+
11071133
#endif
11081134

11091135
/*

paddle/pir/include/dialect/shape/utils/shape_analysis.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,11 @@ class IR_API InferSymbolicShapeCacheKey {
5656
void SetInputShapeOrDatas(
5757
const std::vector<symbol::ShapeOrDataDimExprs>& input_shape_or_datas);
5858
};
59-
struct ConstraintsForInputDimExpr {
60-
symbol::DimExpr dim_expr;
61-
// bind_info = [(input_name, dim_index)]
62-
std::vector<std::pair<std::string, int>> bind_info;
59+
60+
struct InputDynamicDimSpec {
61+
std::string dim_name;
62+
// input_bind = [(input_name, dim_index)]
63+
std::vector<std::pair<std::string, int>> input_bind;
6364
symbol::ConstraintsManager::Range range;
6465
};
6566
} // namespace pir
@@ -81,8 +82,7 @@ class IR_API InferSymbolicShapeContext {
8182
InferSymbolicShapeContext() = default;
8283
InferSymbolicShapeContext(const InferSymbolicShapeContext&) = delete;
8384
InferSymbolicShapeContext(InferSymbolicShapeContext&&) = delete;
84-
void Init(
85-
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints);
85+
void Init(const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec);
8686

8787
// Note: Only initialize the symbol info, the value info is not update.
8888
void RegisterSymbolConstraintFromContext(
@@ -164,6 +164,9 @@ class IR_API InferSymbolicShapeContext {
164164

165165
std::unordered_map<std::string, std::vector<DimIndexAndExpr>>
166166
predefined_dimexpr_map_for_inputs_;
167+
168+
std::unordered_map<std::string, symbol::DimExpr>
169+
input_dynamic_dim_name_spec_to_dimexpr_map_;
167170
};
168171

169172
class IR_API ShapeConstraintIRAnalysis final
@@ -226,10 +229,8 @@ class IR_API ShapeConstraintIRAnalysis final
226229
return context_.constraints_manager();
227230
}
228231

229-
void SetInputShapeConstraints(
230-
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints) {
231-
input_shape_constraints_ = input_shape_constraints;
232-
}
232+
void SetInputDynamicDimSpec(
233+
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec);
233234

234235
private:
235236
InferSymbolicShapeContext* MutInferSymbolicShapeContext() {
@@ -244,7 +245,7 @@ class IR_API ShapeConstraintIRAnalysis final
244245

245246
private:
246247
InferSymbolicShapeContext context_;
247-
std::vector<ConstraintsForInputDimExpr> input_shape_constraints_;
248+
std::vector<InputDynamicDimSpec> input_dynamic_dim_spec_;
248249
};
249250

250251
class IR_API ShapeAnalysisManager {

paddle/pir/src/dialect/shape/utils/shape_analysis.cc

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,32 +32,42 @@ static std::string GetValueId(Value val) {
3232
}
3333

3434
void InferSymbolicShapeContext::Init(
35-
const std::vector<ConstraintsForInputDimExpr>& input_shape_constraints) {
35+
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec) {
3636
value_id_to_shape_or_data_.clear();
3737
next_sym_idx_ = sym_idx_begin_;
3838
constraints_manager_.SetEqualCallbackFunc(
3939
[&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) {
4040
return SubstituteDimExpr(lhs, rhs);
4141
});
4242

43-
const auto& InitBindInfoForInputDim =
44-
[&](const std::vector<std::pair<std::string, int>>& bind_info,
43+
const auto& CreateDimExprForInputDynamicDim = [&]() {
44+
for (const auto& item : input_dynamic_dim_spec) {
45+
input_dynamic_dim_name_spec_to_dimexpr_map_[item.dim_name] =
46+
symbol::DimExpr{GetNextSymName()};
47+
}
48+
};
49+
50+
const auto& SetDynamicShapeInputBind =
51+
[&](const std::vector<std::pair<std::string, int>>& input_bind,
4552
const symbol::DimExpr& dim_expr) {
46-
for (const auto& item : bind_info) {
53+
for (const auto& item : input_bind) {
4754
predefined_dimexpr_map_for_inputs_[item.first].emplace_back(
4855
DimIndexAndExpr(item.second, dim_expr));
4956
}
5057
};
5158

52-
const auto& InitRangeForInputDim =
59+
const auto& SetDynamicShapeInputRange =
5360
[&](const symbol::ConstraintsManager::Range& range,
5461
const symbol::DimExpr& dim_expr) {
5562
constraints_manager_.AddInputRangeCstr(dim_expr, range);
5663
};
5764

58-
for (const auto& item : input_shape_constraints) {
59-
InitBindInfoForInputDim(item.bind_info, item.dim_expr);
60-
InitRangeForInputDim(item.range, item.dim_expr);
65+
CreateDimExprForInputDynamicDim();
66+
for (const auto& item : input_dynamic_dim_spec) {
67+
const auto& dim_expr =
68+
input_dynamic_dim_name_spec_to_dimexpr_map_.at(item.dim_name);
69+
SetDynamicShapeInputBind(item.input_bind, dim_expr);
70+
SetDynamicShapeInputRange(item.range, dim_expr);
6171
}
6272
}
6373

@@ -409,7 +419,7 @@ InferSymbolicShapeContext::GetPredefinedDimExprForInputName(
409419
}
410420

411421
void ShapeConstraintIRAnalysis::InitInferContext() {
412-
context_.Init(input_shape_constraints_);
422+
context_.Init(input_dynamic_dim_spec_);
413423
}
414424

415425
void ShapeConstraintIRAnalysis::RegisterSymbolConstraintFromShapeAnalysis(
@@ -745,6 +755,11 @@ pir::PrintHooks ShapeConstraintIRAnalysis::PrintHook() {
745755
return print_hook;
746756
}
747757

758+
void ShapeConstraintIRAnalysis::SetInputDynamicDimSpec(
759+
const std::vector<InputDynamicDimSpec>& input_dynamic_dim_spec) {
760+
input_dynamic_dim_spec_ = input_dynamic_dim_spec;
761+
}
762+
748763
ShapeAnalysisManager& ShapeAnalysisManager::Instance() {
749764
static ShapeAnalysisManager instance;
750765
return instance;

0 commit comments

Comments
 (0)