Skip to content

Commit 52fead0

Browse files
committed
add several base unittests
1 parent 7b7e988 commit 52fead0

File tree

13 files changed

+179
-40
lines changed

13 files changed

+179
-40
lines changed

paddle/pten/hapi/include/math.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace experimental {
2121

22+
// TODO(chenweihang): add scale API
23+
// TODO(chenweihang): move mean API into stat.h/cc
2224
Tensor mean(const Tensor& x);
2325

2426
} // namespace experimental

paddle/pten/tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
cc_test(pten_backend_test SRCS backend_test.cc DEPS gtest)
2+
cc_test(pten_data_layout_test SRCS data_layout_test.cc DEPS gtest)
3+
cc_test(pten_data_type_test SRCS data_type_test.cc DEPS gtest)
14
cc_test(dense_tensor_test SRCS dense_tensor_test.cc DEPS dense_tensor)
25
cc_test(kernel_factory_test SRCS kernel_factory_test.cc DEPS kernel_factory)
36
cc_test(test_mean_api SRCS test_mean_api.cc DEPS math_api)

paddle/pten/tests/backend_test.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,35 @@ limitations under the License. */
1515
#include "paddle/pten/common/backend.h"
1616

1717
#include <gtest/gtest.h>
18+
#include <iostream>
19+
20+
TEST(Backend, OStream) {
21+
std::ostringstream oss;
22+
oss << pten::Backend::UNDEFINED;
23+
EXPECT_EQ(oss.str(), "Undefined");
24+
oss.str("");
25+
oss << pten::Backend::CPU;
26+
EXPECT_EQ(oss.str(), "CPU");
27+
oss.str("");
28+
oss << pten::Backend::CUDA;
29+
EXPECT_EQ(oss.str(), "CUDA");
30+
oss.str("");
31+
oss << pten::Backend::XPU;
32+
EXPECT_EQ(oss.str(), "XPU");
33+
oss.str("");
34+
oss << pten::Backend::NPU;
35+
EXPECT_EQ(oss.str(), "NPU");
36+
oss.str("");
37+
oss << pten::Backend::MKLDNN;
38+
EXPECT_EQ(oss.str(), "MKLDNN");
39+
oss.str("");
40+
oss << pten::Backend::CUDNN;
41+
EXPECT_EQ(oss.str(), "CUDNN");
42+
oss.str("");
43+
try {
44+
oss << pten::Backend::NUM_BACKENDS;
45+
} catch (paddle::platform::EnforceNotMet &exception) {
46+
std::string ex_msg = exception.what();
47+
EXPECT_TRUE(ex_msg.find("Invalid enum backend type") != std::string::npos);
48+
}
49+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
#include <iostream>
17+
#include <sstream>
18+
#include "paddle/pten/common/layout.h"
19+
20+
TEST(DataLayout, OStream) {
21+
std::ostringstream oss;
22+
oss << pten::DataLayout::UNDEFINED;
23+
EXPECT_EQ(oss.str(), "Undefined");
24+
oss.str("");
25+
oss << pten::DataLayout::ANY;
26+
EXPECT_EQ(oss.str(), "Any");
27+
oss.str("");
28+
oss << pten::DataLayout::NHWC;
29+
EXPECT_EQ(oss.str(), "NHWC");
30+
oss.str("");
31+
oss << pten::DataLayout::NCHW;
32+
EXPECT_EQ(oss.str(), "NCHW");
33+
oss.str("");
34+
oss << pten::DataLayout::MKLDNN;
35+
EXPECT_EQ(oss.str(), "MKLDNN");
36+
oss.str("");
37+
try {
38+
oss << pten::DataLayout::NUM_DATA_LAYOUTS;
39+
} catch (paddle::platform::EnforceNotMet &exception) {
40+
std::string ex_msg = exception.what();
41+
EXPECT_TRUE(ex_msg.find("Invalid enum data layout type") !=
42+
std::string::npos);
43+
}
44+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/pten/common/data_type.h"
16+
17+
#include <gtest/gtest.h>
18+
#include <iostream>
19+
#include <sstream>
20+
21+
TEST(DataType, OStream) {
22+
std::ostringstream oss;
23+
oss << pten::DataType::UNDEFINED;
24+
EXPECT_EQ(oss.str(), "Undefined");
25+
oss.str("");
26+
oss << pten::DataType::BOOL;
27+
EXPECT_EQ(oss.str(), "bool");
28+
oss.str("");
29+
oss << pten::DataType::INT8;
30+
EXPECT_EQ(oss.str(), "int8");
31+
oss.str("");
32+
oss << pten::DataType::UINT8;
33+
EXPECT_EQ(oss.str(), "uint8");
34+
oss.str("");
35+
oss << pten::DataType::INT16;
36+
EXPECT_EQ(oss.str(), "int16");
37+
oss.str("");
38+
oss << pten::DataType::INT32;
39+
EXPECT_EQ(oss.str(), "int32");
40+
oss.str("");
41+
oss << pten::DataType::INT64;
42+
EXPECT_EQ(oss.str(), "int64");
43+
oss.str("");
44+
oss << pten::DataType::BFLOAT16;
45+
EXPECT_EQ(oss.str(), "bfloat16");
46+
oss.str("");
47+
oss << pten::DataType::FLOAT16;
48+
EXPECT_EQ(oss.str(), "float16");
49+
oss.str("");
50+
oss << pten::DataType::FLOAT32;
51+
EXPECT_EQ(oss.str(), "float32");
52+
oss.str("");
53+
oss << pten::DataType::FLOAT64;
54+
EXPECT_EQ(oss.str(), "float64");
55+
oss.str("");
56+
oss << pten::DataType::COMPLEX64;
57+
EXPECT_EQ(oss.str(), "complex64");
58+
oss.str("");
59+
oss << pten::DataType::COMPLEX128;
60+
EXPECT_EQ(oss.str(), "complex128");
61+
oss.str("");
62+
try {
63+
oss << pten::DataType::NUM_DATA_TYPES;
64+
} catch (paddle::platform::EnforceNotMet &exception) {
65+
std::string ex_msg = exception.what();
66+
EXPECT_TRUE(ex_msg.find("Invalid enum data type") != std::string::npos);
67+
}
68+
}

paddle/pten/tests/dense_tensor_test.cc

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,3 @@ TEST(DenseTensor, Constructor) {
3131
ASSERT_EQ(tensor.data_type(), pten::DataType::FLOAT32);
3232
ASSERT_EQ(tensor.layout(), pten::DataLayout::NCHW);
3333
}
34-
35-
TEST(DenseTensor, Dims) {
36-
// impl later
37-
}
38-
39-
TEST(DenseTensor, Place) {
40-
// impl later
41-
}
42-
43-
TEST(DenseTensor, Data) {
44-
// impl later
45-
}

paddle/pten/tests/dtype_test.cc

Lines changed: 0 additions & 13 deletions
This file was deleted.

paddle/pten/tests/kernel_factory_test.cc

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include <iostream>
16+
#include <sstream>
17+
1518
#include "paddle/pten/core/kernel_factory.h"
1619

1720
#include "gtest/gtest.h"
1821

19-
TEST(KernelFactory, KernelKey) {
22+
// TODO(chenweihang): add more unittests later
23+
24+
TEST(KernelName, ConstructAndOStream) {
25+
std::ostringstream oss;
26+
oss << pten::KernelName("scale", "host");
27+
EXPECT_EQ(oss.str(), "scale.host");
28+
pten::KernelName kernel_name1("scale.host");
29+
EXPECT_EQ(kernel_name1.name(), "scale");
30+
EXPECT_EQ(kernel_name1.overload_name(), "host");
31+
pten::KernelName kernel_name2("scale.host");
32+
EXPECT_EQ(kernel_name2.name(), "scale");
33+
EXPECT_EQ(kernel_name2.overload_name(), "host");
34+
}
35+
36+
TEST(KernelKey, ConstructAndOStream) {
2037
pten::KernelKey key(
2138
pten::Backend::CPU, pten::DataLayout::NCHW, pten::DataType::FLOAT32);
22-
std::cout << key;
39+
EXPECT_EQ(key.backend(), pten::Backend::CPU);
40+
EXPECT_EQ(key.layout(), pten::DataLayout::NCHW);
41+
EXPECT_EQ(key.dtype(), pten::DataType::FLOAT32);
42+
std::ostringstream oss;
43+
oss << key;
44+
std::cout << oss.str();
45+
// EXPECT_EQ(oss.str(), "scale.host");
46+
oss.flush();
2347
}

paddle/pten/tests/layout_test.cc

Lines changed: 0 additions & 13 deletions
This file was deleted.

paddle/pten/tests/test_dot_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ PT_DECLARE_MODULE(LinalgCUDA);
2929
namespace framework = paddle::framework;
3030
using DDim = paddle::framework::DDim;
3131

32+
// TODO(chenweihang): Remove this test after the API is used in the dygraph
3233
TEST(API, dot) {
3334
// 1. create tensor
3435
auto dense_x = std::make_shared<pten::DenseTensor>(

0 commit comments

Comments
 (0)