Skip to content

Commit 24ef6c5

Browse files
committed
move scalar and polish enforce
1 parent 19b1095 commit 24ef6c5

File tree

15 files changed

+53
-35
lines changed

15 files changed

+53
-35
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License. */
3030
#include "paddle/fluid/framework/var_type.h"
3131
#include "paddle/fluid/platform/enforce.h"
3232
#include "paddle/fluid/platform/profiler.h"
33+
#include "paddle/pten/common/scalar.h"
3334

3435
namespace paddle {
3536
namespace framework {

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "paddle/fluid/framework/details/nan_inf_utils.h"
1919
#include "paddle/fluid/framework/pten_utils.h"
2020
#include "paddle/fluid/imperative/infer_shape_context.h"
21+
#include "paddle/pten/common/scalar.h"
2122
#include "paddle/utils/small_vector.h"
2223
#ifdef PADDLE_WITH_XPU
2324
#include "paddle/fluid/platform/xpu/xpu_op_list.h"

paddle/pten/api/include/core.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,4 @@ limitations under the License. */
1919
#include "paddle/pten/core/dense_tensor.h"
2020
#include "paddle/pten/core/kernel_context.h"
2121
#include "paddle/pten/core/kernel_factory.h"
22-
#include "paddle/pten/core/scalar.h"
2322
#include "paddle/pten/core/tensor_meta.h"

paddle/pten/common/backend.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616

1717
#include <ostream>
1818

19+
#include "paddle/fluid/platform/enforce.h"
20+
1921
namespace paddle {
2022
namespace experimental {
2123

@@ -78,7 +80,8 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
7880
os << "CUDNN";
7981
break;
8082
default:
81-
throw std::runtime_error("Invalid Backend type.");
83+
PADDLE_THROW(platform::errors::InvalidArgument(
84+
"Invalid enum backend type `%d`.", static_cast<int>(backend)));
8285
}
8386
return os;
8487
}

paddle/pten/common/data_type.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License. */
1818
#include "paddle/fluid/platform/bfloat16.h"
1919
#include "paddle/fluid/platform/complex.h"
2020
#include "paddle/fluid/platform/enforce.h"
21-
#include "paddle/fluid/platform/errors.h"
2221
#include "paddle/fluid/platform/float16.h"
2322

2423
namespace paddle {
@@ -164,13 +163,13 @@ inline std::ostream& operator<<(std::ostream& os, DataType dtype) {
164163
os << "complex128";
165164
break;
166165
default:
167-
// TODO(chenweihang): change to enforce later
168-
throw std::runtime_error("Invalid DataType type.");
166+
PADDLE_THROW(platform::errors::InvalidArgument(
167+
"Invalid enum data type `%d`.", static_cast<int>(dtype)));
169168
}
170169
return os;
171170
}
172171

173-
inline DataType& operator++(DataType& dtype, int) {
172+
inline DataType& operator++(DataType dtype, int) {
174173
dtype =
175174
DataType(static_cast<std::underlying_type<DataType>::type>(dtype) + 1);
176175
return dtype;

paddle/pten/common/layout.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include "paddle/fluid/platform/enforce.h"
18+
1719
namespace paddle {
1820
namespace experimental {
1921

@@ -26,8 +28,8 @@ enum class DataLayout {
2628
NUM_DATA_LAYOUTS,
2729
};
2830

29-
inline std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
30-
switch (dtype) {
31+
inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
32+
switch (layout) {
3133
case DataLayout::UNDEFINED:
3234
os << "Undefined";
3335
break;
@@ -44,13 +46,13 @@ inline std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
4446
os << "MKLDNN";
4547
break;
4648
default:
47-
// TODO(chenweihang): change to enforce later
48-
throw std::runtime_error("Invalid DataLayout type.");
49+
PADDLE_THROW(platform::errors::InvalidArgument(
50+
"Invalid enum data layout type `%d`.", static_cast<int>(layout)));
4951
}
5052
return os;
5153
}
5254

53-
inline DataLayout& operator++(DataLayout& layout, int) {
55+
inline DataLayout& operator++(DataLayout layout, int) {
5456
layout = DataLayout(
5557
static_cast<std::underlying_type<DataLayout>::type>(layout) + 1);
5658
return layout;

paddle/pten/core/scalar.h renamed to paddle/pten/common/scalar.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
namespace pten {
17+
#include <cstdint>
18+
19+
#include "paddle/fluid/platform/enforce.h"
20+
21+
namespace paddle {
22+
namespace experimental {
1823

1924
class Scalar {
2025
public:
@@ -43,7 +48,8 @@ class Scalar {
4348
case Tag::HAS_B:
4449
return static_cast<T>(data_.b);
4550
default:
46-
throw std::runtime_error("Invalid Scalar type.");
51+
PADDLE_THROW(platform::errors::InvalidArgument(
52+
"Invalid enum scalar type tag `%d`.", static_cast<int>(tag)));
4753
}
4854
}
4955

@@ -60,4 +66,9 @@ class Scalar {
6066
} data_;
6167
};
6268

63-
} // namespace pten
69+
} // namespace experimental
70+
} // namespace paddle
71+
72+
namespace pten {
73+
using Scalar = paddle::experimental::Scalar;
74+
}

paddle/pten/core/kernel_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414

1515
#pragma once
1616

17+
#include "paddle/pten/common/scalar.h"
1718
#include "paddle/pten/core/dense_tensor.h"
1819
#include "paddle/pten/core/kernel_context.h"
1920
#include "paddle/pten/core/kernel_def.h"
20-
#include "paddle/pten/core/scalar.h"
2121

2222
// See Note [ Why still include the fluid headers? ]
2323
#include "paddle/fluid/platform/device_context.h"
@@ -163,7 +163,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
163163
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
164164
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
165165
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
166-
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const pten::Scalar&);
166+
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
167167

168168
/* Output Helpers */
169169

paddle/pten/hapi/include/backend_set.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@ limitations under the License. */
1616

1717
#include <ostream>
1818

19-
// TODO(chenweihang): move this file into hapi/include when compile
19+
#include "paddle/fluid/platform/enforce.h"
2020
#include "paddle/pten/common/backend.h"
21-
2221
namespace paddle {
2322
namespace experimental {
2423

@@ -39,10 +38,10 @@ class BackendSet final {
3938
uint64_t bitset() const { return bitset_; }
4039

4140
bool inline Has(Backend b) const {
42-
// TODO(chenweihang): replace by internal assert method later
43-
if (b == Backend::UNDEFINED) {
44-
throw std::runtime_error("Backend argument can't be UNDEFINED.");
45-
}
41+
PADDLE_ENFORCE_NE(b,
42+
Backend::UNDEFINED,
43+
platform::errors::InvalidArgument(
44+
"Backend argument can't be UNDEFINED."));
4645
return static_cast<bool>(bitset_ & BackendSet(b).bitset());
4746
}
4847
bool IsEmpty() const { return bitset_ == 0; }

paddle/pten/hapi/include/creation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
#pragma once
1616

1717
#include "paddle/pten/common/data_type.h"
18-
#include "paddle/pten/core/scalar.h"
18+
#include "paddle/pten/common/scalar.h"
1919
#include "paddle/pten/hapi/include/tensor.h"
2020

2121
namespace paddle {
2222
namespace experimental {
2323

2424
Tensor full_like(const Tensor& x,
25-
const pten::Scalar& value,
25+
const Scalar& value,
2626
DataType dtype = DataType::UNDEFINED);
2727

2828
Tensor ones_like(const Tensor& x, DataType dtype = DataType::UNDEFINED);

0 commit comments

Comments
 (0)