Skip to content

Commit 869e08f

Browse files
committed
add ut
1 parent 8b7783b commit 869e08f

File tree

5 files changed

+80
-3
lines changed

5 files changed

+80
-3
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,10 @@ cc_test(
322322
test_graph_pattern_detector
323323
SRCS graph_pattern_detector_tester.cc
324324
DEPS graph_pattern_detector)
325+
cc_test(
326+
test_save_optimized_model_pass
327+
SRCS save_optimized_model_pass_tester.cc
328+
DEPS save_optimized_model_pass)
325329
cc_test(
326330
test_op_compat_sensible_pass
327331
SRCS op_compat_sensible_pass_tester.cc

paddle/fluid/framework/ir/save_optimized_model_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.

paddle/fluid/framework/ir/save_optimized_model_pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <gtest/gtest.h>
16+
#include "paddle/fluid/framework/ir/pass.h"
17+
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
18+
#include "paddle/fluid/inference/analysis/helper.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
void AddVarToScope(Scope* param_scope,
25+
const std::string& name,
26+
const DDim& dims) {
27+
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
28+
tensor->Resize(dims);
29+
auto* cpu_ctx = static_cast<phi::CPUContext*>(
30+
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
31+
cpu_ctx->Alloc<float>(tensor);
32+
}
33+
34+
VarDesc* Data(paddle::framework::BlockDesc* block,
35+
std::string name,
36+
std::vector<int64_t> shape = {},
37+
bool is_persistable = false,
38+
proto::VarType::Type data_type = proto::VarType::FP32) {
39+
auto* var = block->Var(name);
40+
var->SetType(proto::VarType::LOD_TENSOR);
41+
var->SetDataType(data_type);
42+
var->SetShape(shape);
43+
var->SetPersistable(is_persistable);
44+
return var;
45+
}
46+
47+
TEST(SaveOptimizedModelPass, basic) {
48+
paddle::framework::ProgramDesc program;
49+
auto* block = program.MutableBlock(0);
50+
auto* lookup_table_w = Data(block, "lookup_table_w", {1}, true);
51+
auto* lookup_table_out = Data(block, "scatter_out", {1});
52+
OpDesc* lookup_table = block->AppendOp();
53+
lookup_table->SetType("lookup_table_v2");
54+
lookup_table->SetInput("W", {lookup_table_w->Name()});
55+
lookup_table->SetOutput("Out", {lookup_table_out->Name()});
56+
57+
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
58+
auto scope = new Scope();
59+
AddVarToScope(scope, lookup_table_w->Name(), {1});
60+
graph->Set("__param_scope__", scope);
61+
62+
auto save_optimized_model_pass =
63+
PassRegistry::Instance().Get("save_optimized_model_pass");
64+
save_optimized_model_pass->Set("save_optimized_model", new bool(true));
65+
save_optimized_model_pass->Set("model_opt_cache_dir", new std::string(""));
66+
save_optimized_model_pass->Apply(graph.get());
67+
}
68+
69+
} // namespace ir
70+
} // namespace framework
71+
} // namespace paddle
72+
73+
USE_PASS(save_optimized_model_pass);

paddle/fluid/inference/api/paddle_analysis_config.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ struct PD_INFER_DECL AnalysisConfig {
202202
///
203203
/// \brief Save optimized model.
204204
///
205-
/// \param save_optimized_model Whether to enable save optimized model.
205+
/// \param save_optimized_model whether to enable save optimized model.
206206
///
207207
void EnableSaveOptimizedModel(bool save_optimized_model) {
208208
save_optimized_model_ = save_optimized_model;

0 commit comments

Comments
 (0)