Skip to content

Commit 216bfcc

Browse files
authored
[XPU] ut for save and load op (#65656)
1 parent a30c8a5 commit 216bfcc

File tree

5 files changed

+161
-0
lines changed

5 files changed

+161
-0
lines changed

paddle/fluid/operators/save_op.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,23 @@ PD_REGISTER_KERNEL(save,
105105
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
106106
}
107107

108+
// Registers the `save` kernel (ops::SaveKernel) for the XPU backend with
// ALL_LAYOUT, instantiated for float, double, int, uint8_t, int8_t, int64_t,
// float16 and bfloat16. The registration is compiled only when
// PADDLE_WITH_XPU is defined. As in the registration body above this block,
// input 0 is marked as accepting phi::Backend::ALL_BACKEND.
#ifdef PADDLE_WITH_XPU
109+
PD_REGISTER_KERNEL(save,
110+
XPU,
111+
ALL_LAYOUT,
112+
ops::SaveKernel,
113+
float,
114+
double,
115+
int,
116+
uint8_t,
117+
int8_t,
118+
int64_t,
119+
phi::dtype::float16,
120+
phi::dtype::bfloat16) {
121+
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
122+
}
123+
#endif
124+
108125
PD_REGISTER_KERNEL(save_sr,
109126
CPU,
110127
ALL_LAYOUT,

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,15 @@ XPUOpMap& get_kl2_ops() {
822822
{"roll_grad", XPUKernelSet({phi::DataType::FLOAT32})},
823823
{"rsqrt", XPUKernelSet({phi::DataType::FLOAT32})},
824824
{"rsqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
825+
{"save",
826+
XPUKernelSet({phi::DataType::FLOAT32,
827+
phi::DataType::FLOAT64,
828+
phi::DataType::INT32,
829+
phi::DataType::UINT8,
830+
phi::DataType::INT8,
831+
phi::DataType::INT64,
832+
phi::DataType::FLOAT16,
833+
phi::DataType::BFLOAT16})},
825834
{"scale",
826835
XPUKernelSet({phi::DataType::FLOAT32,
827836
phi::DataType::FLOAT16,

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,15 @@ XPUOpMap& get_kl3_ops() {
830830
{"roll_grad", XPUKernelSet({phi::DataType::FLOAT32})},
831831
{"rsqrt", XPUKernelSet({phi::DataType::FLOAT32})},
832832
{"rsqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
833+
{"save",
834+
XPUKernelSet({phi::DataType::FLOAT32,
835+
phi::DataType::FLOAT64,
836+
phi::DataType::INT32,
837+
phi::DataType::UINT8,
838+
phi::DataType::INT8,
839+
phi::DataType::INT64,
840+
phi::DataType::FLOAT16,
841+
phi::DataType::BFLOAT16})},
833842
{"scale",
834843
XPUKernelSet({phi::DataType::FLOAT32,
835844
phi::DataType::FLOAT16,

test/cpp/fluid/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ paddle_test(assign_op_test SRCS assign_op_test.cc)
2828
paddle_test(scatter_test SRCS scatter_test.cc DEPS common)
2929
paddle_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc)
3030
paddle_test(save_load_op_test SRCS save_load_op_test.cc)
31+
# Build the XPU variant of the save/load operator test only when WITH_XPU is
# enabled; its source is save_load_op_test_xpu.cc.
if(WITH_XPU)
32+
paddle_test(save_load_op_test_xpu SRCS save_load_op_test_xpu.cc)
33+
endif()
3134
paddle_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc)
3235
if(WITH_CINN)
3336
set(CINN_DEPS python)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "gtest/gtest.h"
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/fluid/platform/float16.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
template <typename Place, typename T>
21+
int SaveLoadOpTest(Place place, int dim_1, int dim_2) {
22+
// use cpu place for ground truth
23+
paddle::platform::CPUPlace cpu_place;
24+
std::vector<T> ground_truth_cpu(dim_1 * dim_2);
25+
for (int i = 0; i < dim_1 * dim_2; i++) {
26+
ground_truth_cpu[i] = static_cast<T>(i);
27+
}
28+
29+
// scope, var, tensor and lod
30+
paddle::framework::Scope scope;
31+
auto var = scope.Var("test_var");
32+
auto tensor = var->GetMutable<phi::DenseTensor>();
33+
tensor->Resize({dim_1, dim_2});
34+
paddle::framework::LoD expect_lod;
35+
expect_lod.resize(1);
36+
for (int i = 0; i < dim_1; i++) {
37+
expect_lod[0].push_back(i);
38+
}
39+
tensor->set_lod(expect_lod);
40+
T* src_mutable = tensor->mutable_data<T>(place);
41+
// copy cpu data to tensor
42+
paddle::memory::Copy(place,
43+
src_mutable,
44+
cpu_place,
45+
ground_truth_cpu.data(),
46+
sizeof(T) * ground_truth_cpu.size());
47+
48+
// run save op
49+
paddle::framework::AttributeMap attrs;
50+
attrs.insert({"file_path", std::string("tensor.save")});
51+
auto save_op = paddle::framework::OpRegistry::CreateOp(
52+
"save", {{"X", {"test_var"}}}, {}, attrs);
53+
save_op->Run(scope, place);
54+
55+
// result var and tensor
56+
auto load_var = scope.Var("out_var");
57+
auto target = load_var->GetMutable<phi::DenseTensor>();
58+
59+
// run load op
60+
auto load_op = paddle::framework::OpRegistry::CreateOp(
61+
"load", {}, {{"Out", {"out_var"}}}, attrs);
62+
load_op->Run(scope, place);
63+
64+
// copy result tensor data to cpu
65+
T* actual = target->data<T>();
66+
std::vector<T> actual_cpu(dim_1 * dim_2);
67+
paddle::memory::Copy(cpu_place,
68+
actual_cpu.data(),
69+
place,
70+
actual,
71+
sizeof(T) * ground_truth_cpu.size());
72+
73+
// check result: data
74+
for (int i = 0; i < dim_1 * dim_2; i++) {
75+
if (actual_cpu[i] != ground_truth_cpu[i]) {
76+
return 1;
77+
}
78+
}
79+
80+
// check result: lod
81+
auto& actual_lod = target->lod();
82+
if (expect_lod.size() != actual_lod.size()) {
83+
return 1;
84+
}
85+
for (size_t i = 0; i < expect_lod.size(); ++i) { // NOLINT
86+
for (size_t j = 0; j < expect_lod[i].size(); ++j) {
87+
if (expect_lod[i][j] != actual_lod[i][j]) {
88+
return 1;
89+
}
90+
}
91+
}
92+
return 0;
93+
}
94+
95+
// Exercises the save/load round trip for float, int, float16 and bfloat16.
// For every element type the XPU run (device 0) comes first, followed by the
// same shape on CPU; each run is expected to return 0.
TEST(SaveLoadOp, XPU) {
  using XPUPlace = paddle::platform::XPUPlace;
  using CPUPlace = paddle::platform::CPUPlace;
  using Float16 = paddle::platform::float16;
  using BFloat16 = paddle::platform::bfloat16;

  XPUPlace xpu_place(0);
  CPUPlace cpu_place;

  // float, 3 x 10
  const int float_on_xpu = SaveLoadOpTest<XPUPlace, float>(xpu_place, 3, 10);
  EXPECT_EQ(float_on_xpu, 0);
  const int float_on_cpu = SaveLoadOpTest<CPUPlace, float>(cpu_place, 3, 10);
  EXPECT_EQ(float_on_cpu, 0);

  // int, 2 x 128
  const int int_on_xpu = SaveLoadOpTest<XPUPlace, int>(xpu_place, 2, 128);
  EXPECT_EQ(int_on_xpu, 0);
  const int int_on_cpu = SaveLoadOpTest<CPUPlace, int>(cpu_place, 2, 128);
  EXPECT_EQ(int_on_cpu, 0);

  // float16, 2 x 128
  const int fp16_on_xpu =
      SaveLoadOpTest<XPUPlace, Float16>(xpu_place, 2, 128);
  EXPECT_EQ(fp16_on_xpu, 0);
  const int fp16_on_cpu =
      SaveLoadOpTest<CPUPlace, Float16>(cpu_place, 2, 128);
  EXPECT_EQ(fp16_on_cpu, 0);

  // bfloat16, 4 x 32
  const int bf16_on_xpu =
      SaveLoadOpTest<XPUPlace, BFloat16>(xpu_place, 4, 32);
  EXPECT_EQ(bf16_on_xpu, 0);
  const int bf16_on_cpu =
      SaveLoadOpTest<CPUPlace, BFloat16>(cpu_place, 4, 32);
  EXPECT_EQ(bf16_on_cpu, 0);
}

0 commit comments

Comments
 (0)