Skip to content

Commit 88b43b5

Browse files
Add a new high performance framework for reduce ops (#32697)
1 parent 4920c47 commit 88b43b5

File tree

2 files changed

+704
-0
lines changed

2 files changed

+704
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <string>
17+
#include <vector>
18+
19+
#include "paddle/fluid/framework/eigen.h"
20+
#include "paddle/fluid/framework/tensor.h"
21+
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
22+
#include "paddle/fluid/platform/device_context.h"
23+
#include "paddle/fluid/platform/hostdevice.h"
24+
#include "paddle/fluid/platform/macros.h"
25+
26+
namespace paddle {
27+
namespace operators {
28+
29+
// Binary "minimum" functor.
//
// operator() compares its two arguments with T's operator< and hands back a
// copy of the smaller one. When neither argument is strictly smaller than the
// other (b < a is false), the first argument `a` is the one returned.
// The call operator is declared __device__, so it is callable from device
// code only.
template <typename T>
struct CustomMin {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    // Only a strictly smaller `b` displaces `a`.
    if (b < a) {
      return b;
    }
    return a;
  }
};
35+
36+
// Binary "maximum" functor.
//
// operator() compares its two arguments with T's operator> and hands back a
// copy of the larger one. When b > a is false (including when the two compare
// equal), the first argument `a` is the one returned.
// The call operator is declared __device__, so it is callable from device
// code only.
template <typename T>
struct CustomMax {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    if (b > a) {
      // `b` is strictly greater, so it wins.
      return b;
    } else {
      return a;
    }
  }
};
42+
43+
// Binary "sum" functor.
//
// operator() evaluates `b + a` using T's operator+ and returns the result
// converted to T. Note the operand order: the second argument is the
// left-hand operand of the addition.
// The call operator is declared __device__, so it is callable from device
// code only.
template <typename T>
struct CustomSum {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    T total = b + a;
    return total;
  }
};
49+
50+
// Binary "product" functor.
//
// operator() evaluates `b * a` using T's operator* and returns the result
// converted to T. Note the operand order: the second argument is the
// left-hand operand of the multiplication.
// The call operator is declared __device__, so it is callable from device
// code only.
template <typename T>
struct CustomMul {
  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
    T product = b * a;
    return product;
  }
};
56+
57+
} // namespace operators
58+
} // namespace paddle

0 commit comments

Comments
 (0)