26 changes: 26 additions & 0 deletions paddle/phi/kernels/funcs/pooling.cu
@@ -993,6 +993,7 @@ template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;
template class MaxPool2dGradFunctor<phi::GPUContext, float>;
template class MaxPool2dGradFunctor<phi::GPUContext, double>;
template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
template class MaxPool2dGradFunctor<phi::GPUContext, dtype::bfloat16>;

template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1015,6 +1016,18 @@ template class Pool2dGradFunctor<phi::GPUContext,
template class Pool2dGradFunctor<phi::GPUContext,
AvgPoolGrad<dtype::float16>,
dtype::float16>;
template class Pool2dFunctor<phi::GPUContext,
MaxPool<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool2dFunctor<phi::GPUContext,
AvgPool<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool2dGradFunctor<phi::GPUContext,
MaxPoolGrad<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool2dGradFunctor<phi::GPUContext,
AvgPoolGrad<dtype::bfloat16>,
dtype::bfloat16>;

template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads,
@@ -1863,6 +1876,7 @@ template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;
template class MaxPool3dGradFunctor<phi::GPUContext, float>;
template class MaxPool3dGradFunctor<phi::GPUContext, double>;
template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
template class MaxPool3dGradFunctor<phi::GPUContext, dtype::bfloat16>;

template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1879,12 +1893,24 @@ template class Pool3dFunctor<phi::GPUContext,
template class Pool3dFunctor<phi::GPUContext,
AvgPool<dtype::float16>,
dtype::float16>;
template class Pool3dFunctor<phi::GPUContext,
MaxPool<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool3dFunctor<phi::GPUContext,
AvgPool<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool3dGradFunctor<phi::GPUContext,
MaxPoolGrad<dtype::float16>,
dtype::float16>;
template class Pool3dGradFunctor<phi::GPUContext,
AvgPoolGrad<dtype::float16>,
dtype::float16>;
template class Pool3dGradFunctor<phi::GPUContext,
MaxPoolGrad<dtype::bfloat16>,
dtype::bfloat16>;
template class Pool3dGradFunctor<phi::GPUContext,
AvgPoolGrad<dtype::bfloat16>,
dtype::bfloat16>;

template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(const int nthreads,
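
The bfloat16 lines added above follow this file's existing pattern: the pooling functors are defined in pooling.cu itself, so every <Context, PoolProcess, T> combination used from other translation units must be explicitly instantiated here, or callers fail at link time. A minimal, self-contained sketch of that idea (illustrative names only, not Paddle's actual classes):

// sketch.cu, illustrative only; names here are hypothetical, not Paddle's.
// The template's members are defined in this .cu file, so each element type
// used from other translation units needs its own explicit instantiation,
// exactly like the bfloat16 lines added in the diff above.
template <typename T>
class Pool2dFunctorSketch {
 public:
  void operator()(const T* input, T* output, int n);
};

template <typename T>
void Pool2dFunctorSketch<T>::operator()(const T* input, T* output, int n) {
  // a real implementation would launch a pooling CUDA kernel here
  for (int i = 0; i < n; ++i) output[i] = input[i];
}

// One explicit instantiation per supported dtype; supporting a new dtype
// means adding another line here.
template class Pool2dFunctorSketch<float>;
template class Pool2dFunctorSketch<double>;
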
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/select_impl.cu.h
@@ -268,7 +268,7 @@ __device__ void SelectKernelImpl(OutT *out,
using IdT = int64_t;
// Set index data type
using Add = kps::AddFunctor<IdT>; // for cumsum
- using Cast = NonZeroFunctor<InT>; // for mask
+ using Cast = NonZeroFunctor<MT>; // for mask

IdT init_idx = static_cast<IdT>(0.0f);
MT init_mask = static_cast<MT>(0.0f);
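
Context for the one-line fix above: the 0/1 mask functor is now templated on MT, which appears to be the type the mask values are handled as inside SelectKernelImpl (note the init_mask line), rather than on the storage type InT; the two no longer coincide once InT can be float16 or bfloat16. A rough, hypothetical sketch of such a functor, parameterized on the type it actually receives (not Paddle's real NonZeroFunctor):

// Hypothetical 0/1 mask functor; templated on the value type it is applied
// to (MT in SelectKernelImpl), not on the tensor's storage type (InT).
template <typename MT>
struct NonZeroSketch {
  __host__ __device__ MT operator()(const MT v) const {
    return v != static_cast<MT>(0) ? static_cast<MT>(1) : static_cast<MT>(0);
  }
};
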
11 changes: 9 additions & 2 deletions paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
@@ -15,7 +15,14 @@
#include "paddle/phi/kernels/lgamma_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
- PD_REGISTER_KERNEL(
- lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
+ PD_REGISTER_KERNEL(lgamma_grad,
+ GPU,
+ ALL_LAYOUT,
+ phi::LgammaGradKernel,
+ float,
+ double,
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {}
14 changes: 12 additions & 2 deletions paddle/phi/kernels/gpu/lgamma_kernel.cu
@@ -15,14 +15,17 @@
#include "paddle/phi/kernels/lgamma_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {
template <typename T>
struct CudaLgammaFunctor {
__device__ __forceinline__ T operator()(const T x) const {
- return Eigen::numext::lgamma(x);
+ using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+ const MT mp_x = static_cast<MT>(x);
+ return static_cast<T>(Eigen::numext::lgamma(mp_x));
}
};
template <typename T, typename Context>
@@ -38,4 +41,11 @@ void LgammaKernel(const Context& dev_ctx,
}
} // namespace phi

- PD_REGISTER_KERNEL(lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {}
+ PD_REGISTER_KERNEL(lgamma,
+ GPU,
+ ALL_LAYOUT,
+ phi::LgammaKernel,
+ float,
+ double,
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {}
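
CudaLgammaFunctor now promotes its input to MPTypeTrait<T>::Type (float for float16 and bfloat16), evaluates lgamma in that wider type, and casts the result back, so the half-precision registrations reuse the float path. A standalone CUDA sketch of the same promote-compute-narrow pattern, written with plain __half instead of phi::dtype::float16 (illustrative only):

#include <cuda_fp16.h>

// Promote to float, do the transcendental in float, narrow back to __half.
__device__ __forceinline__ __half LgammaHalfSketch(const __half x) {
  const float xf = __half2float(x);   // widen before the unstable math
  return __float2half(lgammaf(xf));   // evaluate in float, then narrow
}

Doing the transcendental in float avoids the precision and range loss of evaluating lgamma directly in 16-bit arithmetic.
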
5 changes: 4 additions & 1 deletion paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -19,6 +19,7 @@
#include <thrust/reverse.h>
#include <thrust/scan.h>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"

@@ -66,4 +67,6 @@ PD_REGISTER_KERNEL(masked_select_grad,
float,
double,
int,
- int64_t) {}
+ int64_t,
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {}
5 changes: 4 additions & 1 deletion paddle/phi/kernels/gpu/masked_select_kernel.cu
@@ -20,6 +20,7 @@
#include <thrust/scan.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/select_impl.cu.h"

@@ -76,6 +77,8 @@ PD_REGISTER_KERNEL(masked_select,
float,
double,
int,
- int64_t) {
+ int64_t,
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {
kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/pool_grad_kernel.cu
@@ -14,6 +14,7 @@

#include "paddle/phi/kernels/pool_grad_kernel.h"

#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h"
@@ -46,7 +47,8 @@ PD_REGISTER_KERNEL(pool3d_grad,
phi::Pool3dGradKernel,
float,
double,
- phi::dtype::float16) {}
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
GPU,
ALL_LAYOUT,
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/pool_kernel.cu
@@ -14,6 +14,7 @@

#include "paddle/phi/kernels/pool_kernel.h"

#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/pool_kernel_impl.h"
@@ -40,7 +41,8 @@ PD_REGISTER_KERNEL(pool3d,
phi::Pool3dKernel,
float,
double,
- phi::dtype::float16) {}
+ phi::dtype::float16,
+ phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(max_pool3d_with_index,
GPU,
ALL_LAYOUT,
6 changes: 5 additions & 1 deletion paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h
@@ -15,6 +15,7 @@
#pragma once
#include <unsupported/Eigen/SpecialFunctions>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/for_range.h"
namespace phi {
template <typename T>
@@ -23,7 +24,10 @@ struct LgammaGradFunctor {
: dout_(dout), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
- output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
+ using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+ const MT mp_dout = static_cast<MT>(dout_[idx]);
+ const MT mp_x = static_cast<MT>(x_[idx]);
+ output_[idx] = static_cast<T>(mp_dout * Eigen::numext::digamma(mp_x));
}

private:
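
For reference, LgammaGradFunctor above implements the standard identity, evaluated in the promoted type MT:

\[
\frac{d}{dx}\,\ln\Gamma(x) = \psi(x)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \psi(x),
\]

where \psi is the digamma function; this is exactly the mp_dout * digamma(mp_x) product in the new lines.
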
38 changes: 37 additions & 1 deletion python/paddle/fluid/tests/unittests/test_lgamma_op.py
@@ -16,10 +16,11 @@
import unittest

import numpy as np
- from eager_op_test import OpTest
+ from eager_op_test import OpTest, convert_float_to_uint16
from scipy import special

import paddle
from paddle.fluid import core

paddle.enable_static()

@@ -56,6 +57,41 @@ def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)


class TestLgammaFP16Op(TestLgammaOp):
    def init_dtype_type(self):
        self.dtype = np.float16

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out')


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestLgammaBF16Op(OpTest):
    def setUp(self):
        self.op_type = 'lgamma'
        self.python_api = paddle.lgamma
        self.dtype = np.uint16
        shape = (5, 20)
        data = np.random.random(shape).astype("float32") + 1
        self.inputs = {'X': convert_float_to_uint16(data)}
        result = np.ones(shape).astype("float32")
        for i in range(shape[0]):
            for j in range(shape[1]):
                result[i][j] = math.lgamma(data[i][j])
        self.outputs = {'Out': convert_float_to_uint16(result)}

    def test_check_output(self):
        # After testing, bfloat16 needs to set the parameter place
        self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad_normal(self):
        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')


class TestLgammaOpApi(unittest.TestCase):
    def test_lgamma(self):
        paddle.disable_static()
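
The bfloat16 tests above declare self.dtype = np.uint16 because NumPy has no native bfloat16 type; convert_float_to_uint16 packs each float32 value into the 16-bit pattern bfloat16 uses, which (as an assumption about the helper) is essentially the upper half of the float32 bits, possibly with round-to-nearest rather than plain truncation. A truncating C++ sketch of that packing:

#include <cstdint>
#include <cstring>

// Truncating float32 -> bfloat16 conversion: keep the high 16 bits of the
// float32 bit pattern. The test helper may additionally round to nearest
// even; this sketch only illustrates the layout relationship.
uint16_t FloatToBfloat16Trunc(float x) {
  uint32_t bits = 0;
  std::memcpy(&bits, &x, sizeof(bits));   // reinterpret the float's bits
  return static_cast<uint16_t>(bits >> 16);
}

With this layout, converting back for comparison is just a matter of shifting the 16 bits into the top half of a float32.
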
72 changes: 71 additions & 1 deletion python/paddle/fluid/tests/unittests/test_masked_select_op.py
@@ -15,9 +15,10 @@
import unittest

import numpy as np
- from eager_op_test import OpTest
+ from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core


def np_masked_select(x, mask):
@@ -59,6 +60,75 @@ def init(self):
        self.shape = (168,)


class TestMaskedSelectFP16Op(OpTest):
    def setUp(self):
        self.init()
        self.op_type = "masked_select"
        self.dtype = np.float16
        self.python_api = paddle.masked_select
        x = np.random.random(self.shape).astype("float16")
        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
        out = np_masked_select(x, mask)
        self.inputs = {'X': x, 'Mask': mask}
        self.outputs = {'Y': out}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y')

    def init(self):
        self.shape = (50, 3)


class TestMaskedSelectFP16Op1(TestMaskedSelectFP16Op):
    def init(self):
        self.shape = (6, 8, 9, 18)


class TestMaskedSelectFP16Op2(TestMaskedSelectFP16Op):
    def init(self):
        self.shape = (168,)


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support bfloat16",
)
class TestMaskedSelectBF16Op(OpTest):
    def setUp(self):
        self.init()
        self.op_type = "masked_select"
        self.dtype = np.uint16
        self.python_api = paddle.masked_select
        x = np.random.random(self.shape).astype("float32")
        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
        out = np_masked_select(x, mask)
        self.inputs = {'X': convert_float_to_uint16(x), 'Mask': mask}
        self.outputs = {'Y': convert_float_to_uint16(out)}

    def test_check_output(self):
        self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y')

    def init(self):
        self.shape = (50, 3)


class TestMaskedSelectBF16Op1(TestMaskedSelectBF16Op):
    def init(self):
        self.shape = (6, 8, 9, 2)


class TestMaskedSelectBF16Op2(TestMaskedSelectBF16Op):
    def init(self):
        self.shape = (168,)


class TestMaskedSelectAPI(unittest.TestCase):
    def test_imperative_mode(self):
        paddle.disable_static()