11 changes: 9 additions & 2 deletions paddle/phi/kernels/gpu/digamma_grad_kernel.cu
@@ -15,9 +15,16 @@
#include "paddle/phi/kernels/digamma_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/digamma_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
digamma_grad, GPU, ALL_LAYOUT, phi::DigammaGradKernel, float, double) {}
PD_REGISTER_KERNEL(digamma_grad,
GPU,
ALL_LAYOUT,
phi::DigammaGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
11 changes: 9 additions & 2 deletions paddle/phi/kernels/gpu/digamma_kernel.cu
@@ -15,10 +15,17 @@
#include "paddle/phi/kernels/digamma_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/digamma_kernel_impl.h"

PD_REGISTER_KERNEL(
digamma, GPU, ALL_LAYOUT, phi::DigammaKernel, float, double) {}
PD_REGISTER_KERNEL(digamma,
GPU,
ALL_LAYOUT,
phi::DigammaKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/dirichlet_kernel.cu
@@ -112,5 +112,11 @@ struct DirichletSampler<GPUContext, T> {
};
} // namespace phi

PD_REGISTER_KERNEL(
dirichlet, GPU, ALL_LAYOUT, phi::Dirichletkernel, float, double) {}
PD_REGISTER_KERNEL(dirichlet,
GPU,
ALL_LAYOUT,
phi::Dirichletkernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
7 changes: 6 additions & 1 deletion paddle/phi/kernels/impl/digamma_grad_kernel_impl.h
@@ -16,6 +16,7 @@

#include <unsupported/Eigen/SpecialFunctions>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"

@@ -27,7 +28,11 @@ struct DigammaGradFunctor {
: dout_(dout), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = dout_[idx] * Eigen::numext::polygamma(T(1), x_[idx]);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_dout = static_cast<MPType>(dout_[idx]);
const MPType mp_x = static_cast<MPType>(x_[idx]);
output_[idx] =
static_cast<T>(mp_dout * Eigen::numext::polygamma(MPType(1), mp_x));
}

private:
5 changes: 4 additions & 1 deletion paddle/phi/kernels/impl/digamma_kernel_impl.h
@@ -16,6 +16,7 @@

#include <unsupported/Eigen/SpecialFunctions>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/for_range.h"

@@ -27,7 +28,9 @@ struct DigammaFunctor {
: input_(input), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
output_[idx] = Eigen::numext::digamma(input_[idx]);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType mp_input = static_cast<MPType>(input_[idx]);
output_[idx] = static_cast<T>(Eigen::numext::digamma(mp_input));
}

private:
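Both digamma functors above follow the same mixed-precision pattern: values are stored as the tensor's type T (which may now be float16 or bfloat16), promoted to MPTypeTrait<T>::Type for the math, and cast back to T for the write. Below is a minimal standalone sketch of that pattern. MPTypeTraitSketch, LogFunctorSketch, and the float-to-double promotion are illustrative stand-ins (std::log in place of Eigen::numext::digamma, float/double in place of float16-or-bfloat16/float), not Paddle's actual trait or functor.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for phi::dtype::MPTypeTrait: map a storage type to the type the
// math should run in. Full-precision types compute in themselves.
template <typename T>
struct MPTypeTraitSketch {
  using Type = T;
};
// Stand-in for the float16/bfloat16 -> float promotion in the real trait.
template <>
struct MPTypeTraitSketch<float> {
  using Type = double;
};

// Mirrors the DigammaFunctor shape: read T, compute in MPType, write back T.
// std::log stands in for Eigen::numext::digamma.
template <typename T>
struct LogFunctorSketch {
  const T* input_;
  T* output_;
  void operator()(int64_t idx) const {
    using MPType = typename MPTypeTraitSketch<T>::Type;
    const MPType mp_input = static_cast<MPType>(input_[idx]);
    output_[idx] = static_cast<T>(std::log(mp_input));
  }
};

int main() {
  std::vector<float> in{1.0f, 2.0f, 3.0f};
  std::vector<float> out(3);
  LogFunctorSketch<float> f{in.data(), out.data()};
  for (int64_t i = 0; i < 3; ++i) f(i);
  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';
}

The real MPTypeTrait in paddle/phi/common/amp_type_traits.h maps both phi::dtype::float16 and phi::dtype::bfloat16 to float, which is why the kernel registrations above can add the two half-precision types without changing the underlying math.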
41 changes: 26 additions & 15 deletions paddle/phi/kernels/impl/dirichlet_kernel_impl.h
@@ -16,6 +16,7 @@

#include <cmath>
#include <random>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/dirichlet_kernel.h"

// ROCM hcc doesn't work well with using std:: in kernel functions
@@ -47,7 +48,10 @@ template <typename ScalarT, typename SamplerT>
struct BaseSampler {
SamplerT sampler_;
HOSTDEVICE BaseSampler(const SamplerT& sampler) : sampler_(sampler) {}
HOSTDEVICE ScalarT sample() { return sampler_(); }
HOSTDEVICE ScalarT sample() {
// Sometimes convert float to float16/bfloat16
return static_cast<ScalarT>(sampler_());
}
};

// `sample_gamma` is d from Numpy's distributions.c, and add support for
@@ -83,33 +87,40 @@ HOSTDEVICE ScalarT
sample_gamma(ScalarT alpha,
Review comment (Contributor):
When this is called from outside, does the template argument list also need to be updated?

auto sample =
    sample_gamma<T, T, decltype(uniform_lambda), decltype(normal_lambda)>(
        alpha_[index], standard_uniform, standard_normal);

Reply (Contributor, Author):
I don't see a problem here. At the call site the alpha and gamma parameters passed to sample_gamma are both of type T, while sample_gamma is declared with the two template types ScalarT and AccscalarT, so the call works as written.

[A standalone sketch of this call pattern follows the dirichlet_kernel_impl.h diff below.]

BaseSampler<AccscalarT, UniformSamplerT> standard_uniform,
BaseSampler<AccscalarT, NormalSamplerT> standard_normal) {
AccscalarT scale = 1.0f;
using MPTypeScalar = typename phi::dtype::MPTypeTrait<ScalarT>::Type;
using MPTypeAccscalar = typename phi::dtype::MPTypeTrait<AccscalarT>::Type;

MPTypeAccscalar mp_scale = static_cast<MPTypeAccscalar>(1.0f);
MPTypeScalar mp_alpha = static_cast<MPTypeScalar>(alpha);

// Boost alpha for higher acceptance probability.
if (alpha < 1.0f) {
if (alpha == 0.f) return 0.f;
scale *= COMPAT_POW(1 - standard_uniform.sample(), 1.0f / alpha);
alpha += 1.0f;
if (mp_alpha < 1.0f) {
if (mp_alpha == 0.f) return static_cast<ScalarT>(0.f);
MPTypeAccscalar mp_sample =
static_cast<MPTypeAccscalar>(standard_uniform.sample());
mp_scale *= COMPAT_POW(1 - mp_sample, 1.0f / mp_alpha);
mp_alpha += 1.0f;
}

// This implements the acceptance-rejection method of Marsaglia and Tsang
// (2000)
// doi:10.1145/358407.358414
const AccscalarT d = alpha - 1.0f / 3.0f;
const AccscalarT c = 1.0f / COMPAT_SQRT(9.0f * d);
const MPTypeAccscalar d = mp_alpha - 1.0f / 3.0f;
const MPTypeAccscalar c = 1.0f / COMPAT_SQRT(9.0f * d);
for (;;) {
AccscalarT x, y;
MPTypeAccscalar x, y;
do {
x = standard_normal.sample();
x = static_cast<MPTypeAccscalar>(standard_normal.sample());
y = 1.0f + c * x;
} while (y <= 0);
const AccscalarT v = y * y * y;
const AccscalarT u = 1 - standard_uniform.sample();
const AccscalarT xx = x * x;
const MPTypeAccscalar v = y * y * y;
const MPTypeAccscalar u =
1 - static_cast<MPTypeAccscalar>(standard_uniform.sample());
const MPTypeAccscalar xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
return static_cast<ScalarT>(scale * d * v);
return static_cast<ScalarT>(mp_scale * d * v);
if (COMPAT_LOG(u) < 0.5f * xx + d * (1.0f - v + COMPAT_LOG(v)))
return static_cast<ScalarT>(scale * d * v);
return static_cast<ScalarT>(mp_scale * d * v);
}
}

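Regarding the review question above about the sample_gamma<T, T, ...> call: the storage type can be passed for both template slots because the function promotes everything to MPType internally. Below is a standalone sketch of the Marsaglia and Tsang (2000) acceptance-rejection sampler that sample_gamma implements, written against the C++ standard library only. sample_gamma_sketch, the <random> distributions, and the main() driver are illustrative assumptions: the kernel draws its uniform/normal variates through the BaseSampler wrappers and relies on MPTypeTrait rather than an explicit accumulation-type parameter.

#include <cmath>
#include <iostream>
#include <random>

// ScalarT: storage type (float16/bfloat16 in the kernel; float here).
// AccT:    type the math runs in (float in the kernel via MPTypeTrait; double here).
template <typename ScalarT, typename AccT, typename Rng>
ScalarT sample_gamma_sketch(ScalarT alpha_in, Rng& rng) {
  std::uniform_real_distribution<AccT> uniform(0, 1);
  std::normal_distribution<AccT> normal(0, 1);

  AccT alpha = static_cast<AccT>(alpha_in);
  AccT scale = 1;

  // Boost alpha < 1 for a higher acceptance probability; undo via `scale`.
  if (alpha < 1) {
    if (alpha == 0) return static_cast<ScalarT>(0);
    scale *= std::pow(1 - uniform(rng), 1 / alpha);
    alpha += 1;
  }

  // Acceptance-rejection method of Marsaglia and Tsang (2000),
  // doi:10.1145/358407.358414.
  const AccT d = alpha - AccT(1) / 3;
  const AccT c = 1 / std::sqrt(9 * d);
  for (;;) {
    AccT x, y;
    do {
      x = normal(rng);
      y = 1 + c * x;
    } while (y <= 0);
    const AccT v = y * y * y;
    const AccT u = 1 - uniform(rng);
    const AccT xx = x * x;
    if (u < 1 - AccT(0.0331) * xx * xx)
      return static_cast<ScalarT>(scale * d * v);
    if (std::log(u) < AccT(0.5) * xx + d * (1 - v + std::log(v)))
      return static_cast<ScalarT>(scale * d * v);
  }
}

int main() {
  std::mt19937 rng(0);
  // The kernel instantiates sample_gamma<T, T, ...> and promotes to MPType
  // inside the function; this sketch exposes the accumulation type as an
  // explicit template parameter instead.
  double sum = 0;
  for (int i = 0; i < 100000; ++i) {
    sum += sample_gamma_sketch<float, double>(2.0f, rng);
  }
  std::cout << "sample mean ~ " << sum / 100000 << " (Gamma(2,1) mean is 2)\n";
}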
5 changes: 4 additions & 1 deletion python/paddle/distribution/dirichlet.py
@@ -164,7 +164,10 @@ def _dirichlet(concentration, name=None):
else:
op_type = 'dirichlet'
check_variable_and_dtype(
concentration, 'concentration', ['float32', 'float64'], op_type
concentration,
'concentration',
['float16', 'float32', 'float64', 'uint16'],
op_type,
)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(
@@ -20,7 +20,15 @@
import paddle

sys.path.append("../")
from eager_op_test import OpTest
import unittest

from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)

from paddle.fluid import core

paddle.enable_static()

@@ -52,3 +60,89 @@ def _hypothesis_testing(self, outs):
)[0],
0.01,
)


class TestDirichletFP16Op(OpTest):
# Because dirichlet random sample have not gradient, we skip gradient check.
no_need_check_grad = True

def setUp(self):
self.op_type = "dirichlet"
self.alpha = np.array((1.0, 2.0))
self.sample_shape = (100000, 2)
self.dtype = np.float16

self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
self.dtype
)
}
self.attrs = {}
self.outputs = {'Out': np.zeros(self.sample_shape).astype(self.dtype)}

def test_check_output(self):
self.check_output_customized(self._hypothesis_testing)

def _hypothesis_testing(self, outs):
self.assertEqual(outs[0].shape, self.sample_shape)
self.assertTrue(np.all(outs[0] > 0.0))
self.assertLess(
scipy.stats.kstest(
outs[0][:, 0],
# scipy dirichlet have not cdf, use beta to replace it.
scipy.stats.beta(a=self.alpha[0], b=self.alpha[1]).cdf,
)[0],
0.01,
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestDirichletBF16Op(OpTest):
# Because dirichlet random sample have not gradient, we skip gradient check.
no_need_check_grad = True

def setUp(self):
self.op_type = "dirichlet"
self.alpha = np.array((1.0, 2.0))
self.sample_shape = (10000, 2)
self.dtype = np.uint16
self.np_dtype = np.float32

self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
self.np_dtype
)
}
self.attrs = {}
self.outputs = {
'Out': np.zeros(self.sample_shape).astype(self.np_dtype)
}
self.inputs['Alpha'] = convert_float_to_uint16(self.inputs['Alpha'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place_customized(
self._hypothesis_testing, place=core.CUDAPlace(0)
)

def _hypothesis_testing(self, outs):
outs = convert_uint16_to_float(outs)
self.assertEqual(outs[0].shape, self.sample_shape)
self.assertTrue(np.all(outs[0] > 0.0))
self.assertLess(
scipy.stats.kstest(
outs[0][:, 0],
# scipy dirichlet have not cdf, use beta to replace it.
scipy.stats.beta(a=self.alpha[0], b=self.alpha[1]).cdf,
)[0],
0.3, # The bfloat16 test difference is below 0.3
)


if __name__ == '__main__':
unittest.main()
40 changes: 39 additions & 1 deletion python/paddle/fluid/tests/unittests/test_digamma_op.py
@@ -15,11 +15,12 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from scipy.special import psi

import paddle
from paddle import fluid, static
from paddle.fluid import core


class TestDigammaOp(OpTest):
@@ -55,6 +56,43 @@ def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')


class TestDigammaFP16Op(TestDigammaOp):
def init_dtype_type(self):
self.dtype = np.float16


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestDigammaBF16Op(OpTest):
def setUp(self):
# switch to static
paddle.enable_static()

self.op_type = 'digamma'
self.python_api = paddle.digamma
self.init_dtype_type()
shape = (5, 32)
data = np.random.random(shape).astype(self.np_dtype) + 1
self.inputs = {'X': convert_float_to_uint16(data)}
result = np.ones(shape).astype(self.np_dtype)
result = psi(data)
self.outputs = {'Out': convert_float_to_uint16(result)}

def init_dtype_type(self):
self.dtype = np.uint16
self.np_dtype = np.float32

def test_check_output(self):
# bfloat16 needs to set the parameter place
self.check_output_with_place(core.CUDAPlace(0))

def test_check_grad_normal(self):
self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')


class TestDigammaAPI(unittest.TestCase):
def setUp(self):
# switch to static
4 changes: 3 additions & 1 deletion python/paddle/tensor/math.py
@@ -4011,7 +4011,9 @@ def digamma(x, name=None):
if in_dygraph_mode():
return _C_ops.digamma(x)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'digamma')
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'digamma'
)
helper = LayerHelper('digamma', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='digamma', inputs={'X': x}, outputs={'Out': out})