Skip to content

Commit b622e96

Browse files
authored
【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API (#58092)
1 parent bba58af commit b622e96

File tree

12 files changed

+1074
-0
lines changed

12 files changed

+1074
-0
lines changed

paddle/phi/api/yaml/ops.yaml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@
343343
backend : x
344344
inplace: (x -> out)
345345

346+
- op : bitwise_left_shift
347+
args : (Tensor x, Tensor y, bool is_arithmetic = true)
348+
output : Tensor(out)
349+
infer_meta :
350+
func : BitwiseShiftInferMeta
351+
kernel :
352+
func : bitwise_left_shift
353+
backend : x
354+
inplace: (x -> out)
355+
346356
- op : bitwise_not
347357
args : (Tensor x)
348358
output : Tensor(out)
@@ -364,6 +374,16 @@
364374
backend : x
365375
inplace: (x -> out)
366376

377+
- op : bitwise_right_shift
378+
args : (Tensor x, Tensor y, bool is_arithmetic = true)
379+
output : Tensor(out)
380+
infer_meta :
381+
func : BitwiseShiftInferMeta
382+
kernel :
383+
func : bitwise_right_shift
384+
backend : x
385+
inplace: (x -> out)
386+
367387
- op : bitwise_xor
368388
args : (Tensor x, Tensor y)
369389
output : Tensor(out)

paddle/phi/infermeta/binary.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,13 @@ void ElementwiseInferMeta(const MetaTensor& x,
12691269
return ElementwiseRawInferMeta(x, y, -1, out);
12701270
}
12711271

1272+
void BitwiseShiftInferMeta(const MetaTensor& x,
1273+
const MetaTensor& y,
1274+
bool is_arithmetic,
1275+
MetaTensor* out) {
1276+
return ElementwiseRawInferMeta(x, y, -1, out);
1277+
}
1278+
12721279
void ElementwiseRawInferMeta(const MetaTensor& x,
12731280
const MetaTensor& y,
12741281
int axis,

paddle/phi/infermeta/binary.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta,
231231
MetaTensor* out,
232232
MetaConfig config = MetaConfig());
233233

234+
void BitwiseShiftInferMeta(const MetaTensor& x,
235+
const MetaTensor& y,
236+
bool is_arithmetic,
237+
MetaTensor* out);
238+
234239
void EmbeddingInferMeta(const MetaTensor& x,
235240
const MetaTensor& weight,
236241
int64_t padding_idx,

paddle/phi/kernels/bitwise_kernel.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,18 @@ void BitwiseNotKernel(const Context& dev_ctx,
4141
const DenseTensor& x,
4242
DenseTensor* out);
4343

44+
template <typename T, typename Context>
45+
void BitwiseLeftShiftKernel(const Context& dev_ctx,
46+
const DenseTensor& x,
47+
const DenseTensor& y,
48+
bool is_arithmetic,
49+
DenseTensor* out);
50+
51+
template <typename T, typename Context>
52+
void BitwiseRightShiftKernel(const Context& dev_ctx,
53+
const DenseTensor& x,
54+
const DenseTensor& y,
55+
bool is_arithmetic,
56+
DenseTensor* out);
57+
4458
} // namespace phi

paddle/phi/kernels/cpu/bitwise_kernel.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,45 @@ DEFINE_BITWISE_KERNEL(Or)
4040
DEFINE_BITWISE_KERNEL(Xor)
4141
#undef DEFINE_BITWISE_KERNEL
4242

43+
#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \
44+
template <typename T, typename Context> \
45+
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
46+
const DenseTensor& x, \
47+
const DenseTensor& y, \
48+
bool is_arithmetic, \
49+
DenseTensor* out) { \
50+
auto x_dims = x.dims(); \
51+
auto y_dims = y.dims(); \
52+
if (x_dims.size() >= y_dims.size()) { \
53+
if (is_arithmetic) { \
54+
funcs::Bitwise##op_type##ArithmeticFunctor<T> func; \
55+
funcs::ElementwiseCompute< \
56+
funcs::Bitwise##op_type##ArithmeticFunctor<T>, \
57+
T>(dev_ctx, x, y, func, out); \
58+
} else { \
59+
funcs::Bitwise##op_type##LogicFunctor<T> func; \
60+
funcs::ElementwiseCompute<funcs::Bitwise##op_type##LogicFunctor<T>, \
61+
T>(dev_ctx, x, y, func, out); \
62+
} \
63+
} else { \
64+
if (is_arithmetic) { \
65+
funcs::InverseBitwise##op_type##ArithmeticFunctor<T> inv_func; \
66+
funcs::ElementwiseCompute< \
67+
funcs::InverseBitwise##op_type##ArithmeticFunctor<T>, \
68+
T>(dev_ctx, x, y, inv_func, out); \
69+
} else { \
70+
funcs::InverseBitwise##op_type##LogicFunctor<T> inv_func; \
71+
funcs::ElementwiseCompute< \
72+
funcs::InverseBitwise##op_type##LogicFunctor<T>, \
73+
T>(dev_ctx, x, y, inv_func, out); \
74+
} \
75+
} \
76+
}
77+
78+
DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShift)
79+
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShift)
80+
#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE
81+
4382
template <typename T, typename Context>
4483
void BitwiseNotKernel(const Context& dev_ctx,
4584
const DenseTensor& x,
@@ -97,3 +136,23 @@ PD_REGISTER_KERNEL(bitwise_not,
97136
int16_t,
98137
int,
99138
int64_t) {}
139+
140+
PD_REGISTER_KERNEL(bitwise_left_shift,
141+
CPU,
142+
ALL_LAYOUT,
143+
phi::BitwiseLeftShiftKernel,
144+
uint8_t,
145+
int8_t,
146+
int16_t,
147+
int,
148+
int64_t) {}
149+
150+
PD_REGISTER_KERNEL(bitwise_right_shift,
151+
CPU,
152+
ALL_LAYOUT,
153+
phi::BitwiseRightShiftKernel,
154+
uint8_t,
155+
int8_t,
156+
int16_t,
157+
int,
158+
int64_t) {}

paddle/phi/kernels/funcs/bitwise_functors.h

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,164 @@ struct BitwiseNotFunctor<bool> {
4747
HOSTDEVICE bool operator()(const bool a) const { return !a; }
4848
};
4949

50+
template <typename T>
51+
struct BitwiseLeftShiftArithmeticFunctor {
52+
HOSTDEVICE T operator()(const T a, const T b) const {
53+
if (b >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
54+
if (b < static_cast<T>(0)) return static_cast<T>(0);
55+
return a << b;
56+
}
57+
};
58+
59+
template <typename T>
60+
struct InverseBitwiseLeftShiftArithmeticFunctor {
61+
inline HOSTDEVICE T operator()(const T a, const T b) const {
62+
if (a >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
63+
if (a < static_cast<T>(0)) return static_cast<T>(0);
64+
return b << a;
65+
}
66+
};
67+
68+
template <typename T>
69+
struct BitwiseLeftShiftLogicFunctor {
70+
HOSTDEVICE T operator()(const T a, const T b) const {
71+
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
72+
return static_cast<T>(0);
73+
return a << b;
74+
}
75+
};
76+
77+
template <typename T>
78+
struct InverseBitwiseLeftShiftLogicFunctor {
79+
inline HOSTDEVICE T operator()(const T a, const T b) const {
80+
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
81+
return static_cast<T>(0);
82+
return b << a;
83+
}
84+
};
85+
86+
template <typename T>
87+
struct BitwiseRightShiftArithmeticFunctor {
88+
HOSTDEVICE T operator()(const T a, const T b) const {
89+
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
90+
return static_cast<T>(-(a >> (sizeof(T) * 8 - 1) & 1));
91+
return a >> b;
92+
}
93+
};
94+
95+
template <typename T>
96+
struct InverseBitwiseRightShiftArithmeticFunctor {
97+
inline HOSTDEVICE T operator()(const T a, const T b) const {
98+
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
99+
return static_cast<T>(-(b >> (sizeof(T) * 8 - 1) & 1));
100+
return b >> a;
101+
}
102+
};
103+
104+
template <>
105+
struct BitwiseRightShiftArithmeticFunctor<uint8_t> {
106+
HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
107+
if (b >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
108+
return static_cast<uint8_t>(0);
109+
return a >> b;
110+
}
111+
};
112+
113+
template <>
114+
struct InverseBitwiseRightShiftArithmeticFunctor<uint8_t> {
115+
inline HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
116+
if (a >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
117+
return static_cast<uint8_t>(0);
118+
return b >> a;
119+
}
120+
};
121+
122+
template <typename T>
123+
struct BitwiseRightShiftLogicFunctor {
124+
HOSTDEVICE T operator()(const T a, const T b) const {
125+
if (b >= static_cast<T>(sizeof(T) * 8) || b < static_cast<T>(0))
126+
return static_cast<T>(0);
127+
return a >> b;
128+
}
129+
};
130+
131+
template <typename T>
132+
struct InverseBitwiseRightShiftLogicFunctor {
133+
inline HOSTDEVICE T operator()(const T a, const T b) const {
134+
if (a >= static_cast<T>(sizeof(T) * 8) || a < static_cast<T>(0))
135+
return static_cast<T>(0);
136+
return b >> a;
137+
}
138+
};
139+
140+
template <typename T>
141+
HOSTDEVICE T logic_shift_func(const T a, const T b) {
142+
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
143+
return static_cast<T>(0);
144+
T t = static_cast<T>(sizeof(T) * 8 - 1);
145+
T mask = (((a >> t) << t) >> b) << 1;
146+
return (a >> b) ^ mask;
147+
}
148+
149+
// signed int8
150+
template <>
151+
struct BitwiseRightShiftLogicFunctor<int8_t> {
152+
HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
153+
return logic_shift_func<int8_t>(a, b);
154+
}
155+
};
156+
157+
template <>
158+
struct InverseBitwiseRightShiftLogicFunctor<int8_t> {
159+
inline HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
160+
return logic_shift_func<int8_t>(b, a);
161+
}
162+
};
163+
164+
// signed int16
165+
template <>
166+
struct BitwiseRightShiftLogicFunctor<int16_t> {
167+
HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
168+
return logic_shift_func<int16_t>(a, b);
169+
}
170+
};
171+
172+
template <>
173+
struct InverseBitwiseRightShiftLogicFunctor<int16_t> {
174+
inline HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
175+
return logic_shift_func<int16_t>(b, a);
176+
}
177+
};
178+
179+
// signed int32
180+
template <>
181+
struct BitwiseRightShiftLogicFunctor<int> {
182+
HOSTDEVICE int operator()(const int a, const int b) const {
183+
return logic_shift_func<int32_t>(a, b);
184+
}
185+
};
186+
187+
template <>
188+
struct InverseBitwiseRightShiftLogicFunctor<int> {
189+
inline HOSTDEVICE int operator()(const int a, const int b) const {
190+
return logic_shift_func<int32_t>(b, a);
191+
}
192+
};
193+
194+
// signed int64
195+
template <>
196+
struct BitwiseRightShiftLogicFunctor<int64_t> {
197+
HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
198+
return logic_shift_func<int64_t>(a, b);
199+
}
200+
};
201+
202+
template <>
203+
struct InverseBitwiseRightShiftLogicFunctor<int64_t> {
204+
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
205+
return logic_shift_func<int64_t>(b, a);
206+
}
207+
};
208+
50209
} // namespace funcs
51210
} // namespace phi

paddle/phi/kernels/kps/bitwise_kernel.cu

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@ DEFINE_BITWISE_KERNEL(Or)
4343
DEFINE_BITWISE_KERNEL(Xor)
4444
#undef DEFINE_BITWISE_KERNEL
4545

46+
#define DEFINE_BITWISE_KERNEL_WITH_BOOL(op_type) \
47+
template <typename T, typename Context> \
48+
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
49+
const DenseTensor& x, \
50+
const DenseTensor& y, \
51+
bool is_arithmetic, \
52+
DenseTensor* out) { \
53+
dev_ctx.template Alloc<T>(out); \
54+
std::vector<const DenseTensor*> ins = {&x, &y}; \
55+
std::vector<DenseTensor*> outs = {out}; \
56+
if (is_arithmetic) { \
57+
funcs::Bitwise##op_type##ArithmeticFunctor<T> func; \
58+
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
59+
} else { \
60+
funcs::Bitwise##op_type##LogicFunctor<T> func; \
61+
funcs::BroadcastKernel<T>(dev_ctx, ins, &outs, func); \
62+
} \
63+
}
64+
65+
DEFINE_BITWISE_KERNEL_WITH_BOOL(LeftShift)
66+
DEFINE_BITWISE_KERNEL_WITH_BOOL(RightShift)
67+
#undef DEFINE_BITWISE_KERNEL_WITH_BOOL
68+
4669
template <typename T, typename Context>
4770
void BitwiseNotKernel(const Context& dev_ctx,
4871
const DenseTensor& x,
@@ -112,4 +135,24 @@ PD_REGISTER_KERNEL(bitwise_not,
112135
int,
113136
int64_t) {}
114137

138+
PD_REGISTER_KERNEL(bitwise_left_shift,
139+
KPS,
140+
ALL_LAYOUT,
141+
phi::BitwiseLeftShiftKernel,
142+
uint8_t,
143+
int8_t,
144+
int16_t,
145+
int,
146+
int64_t) {}
147+
148+
PD_REGISTER_KERNEL(bitwise_right_shift,
149+
KPS,
150+
ALL_LAYOUT,
151+
phi::BitwiseRightShiftKernel,
152+
uint8_t,
153+
int8_t,
154+
int16_t,
155+
int,
156+
int64_t) {}
157+
115158
#endif

python/paddle/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,10 @@
358358
atan_,
359359
atanh,
360360
atanh_,
361+
bitwise_left_shift,
362+
bitwise_left_shift_,
363+
bitwise_right_shift,
364+
bitwise_right_shift_,
361365
broadcast_shape,
362366
ceil,
363367
clip,
@@ -944,6 +948,10 @@
944948
'i1e',
945949
'polygamma',
946950
'polygamma_',
951+
'bitwise_left_shift',
952+
'bitwise_left_shift_',
953+
'bitwise_right_shift',
954+
'bitwise_right_shift_',
947955
'masked_fill',
948956
'masked_fill_',
949957
'masked_scatter',

0 commit comments

Comments
 (0)